Commit e8702f15 authored by haileyschoelkopf

Merge branch 'big-refactor' into unscramble+toxigen

parents dcb16263 45737a38
@@ -80,6 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
+    "winograd_schema": ["acc"],
     "greedy_until": ["exact_match"],
 }
......
@@ -43,6 +43,7 @@ ALL_OUTPUT_TYPES = [
     "multiple_choice",
     "loglikelihood_rolling",
     "greedy_until",
+    "winograd_schema"
 ]
@@ -73,9 +74,10 @@ class TaskConfig(dict):
     fewshot_delimiter: str = "\n\n"
     # runtime configuration options
     num_fewshot: int = 0
-    batch_size: int = 1
     # scoring options
     metric_list: str = None
+    gold_alias: Union[Callable, str] = None
+    create_choices: Union[Callable, str] = None
     output_type: str = "greedy_until"
     generation_kwargs: dict = None
     repeats: int = 1
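For illustration, a hypothetical use of the two new `TaskConfig` fields for a Winograd-style task: only the field names come from this diff, while the doc keys and values below are invented placeholders.

```python
# Hypothetical values for the two new TaskConfig fields; the doc keys
# ("sentence_option1", "sentence_option2", "label") are invented placeholders.
example_overrides = dict(
    output_type="winograd_schema",
    # Builds the list of candidate contexts: either a Jinja template string
    # or a callable that receives the doc.
    create_choices=lambda doc: [doc["sentence_option1"], doc["sentence_option2"]],
    # Optionally overrides doc_to_target when recovering the gold label.
    gold_alias=lambda doc: doc["label"],
)
```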
@@ -297,6 +299,18 @@ class Task(abc.ABC):
             The processed version of the specified `doc`.
         """
         return doc
 
+    def create_choices(self, doc):
+        if self._config.create_choices is None:
+            return ast.literal_eval(
+                utils.apply_template(
+                    self._config.template_aliases + "{{answer_choices}}", doc
+                )
+            )
+        elif type(self._config.create_choices) == str:
+            return utils.apply_template(self._config.create_choices, doc)
+        else:
+            return self._config.create_choices(doc)
+
     @property
     def instances(self):
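A standalone sketch of the dispatch `create_choices` performs above, using `jinja2` directly in place of the harness's `utils.apply_template`; the example doc and choice values are invented.

```python
# Sketch of Task.create_choices' three cases: None falls back to rendering
# "{{answer_choices}}" from template_aliases, a string is treated as a Jinja
# template, and anything else is assumed to be a callable over the doc.
import ast

from jinja2 import Template


def resolve_choices(create_choices, template_aliases, doc):
    if create_choices is None:
        # Render "{{answer_choices}}" and parse the result into a Python list.
        rendered = Template(template_aliases + "{{answer_choices}}").render(**doc)
        return ast.literal_eval(rendered)
    elif isinstance(create_choices, str):
        # Treat the string as a Jinja template over the doc.
        return Template(create_choices).render(**doc)
    else:
        # Otherwise assume a callable that receives the doc.
        return create_choices(doc)


doc = {"answer_choices": "['yes', 'no']"}
print(resolve_choices(None, "", doc))                     # ['yes', 'no']
print(resolve_choices(lambda d: ["yes", "no"], "", doc))  # ['yes', 'no']
```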
@@ -468,7 +482,7 @@ class Task(abc.ABC):
             The fewshot context.
         """
         # TODO: this should only return the overrides applied to a non-YAML task's configuration.
-        # (batch size, num_fewshot)
+        # (num_fewshot)
         return self._config.to_dict()
@@ -728,11 +742,8 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
             # TODO: any cleaner way to do this?
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
+
             request_list = [
                 Instance(
                     request_type="loglikelihood",
@@ -768,6 +779,26 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "greedy_until":
             arguments = (ctx, self._config.generation_kwargs)
 
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            # similar to multiple_choice task type except each request contains
+            # multiple differing contexts with the same continuation
+            contexts = self.create_choices(doc)
+            choice = self.doc_to_target(doc)
+
+            request_list = [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(context, " {}".format(choice)),
+                    idx=i,
+                    **kwargs,
+                )
+                for i, context in enumerate(contexts)
+            ]
+
+            return request_list
+
         return Instance(
             request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
         )
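For concreteness, a minimal sketch of the request layout the new `winograd_schema` branch produces: every candidate context is paired with the same target continuation, and each pair becomes one loglikelihood request. The doc below is an invented Winograd-style example, not taken from any dataset.

```python
# Invented Winograd-style example: the candidate contexts differ,
# the continuation is shared.
contexts = [
    "The trophy doesn't fit in the suitcase because the suitcase",
    "The trophy doesn't fit in the suitcase because the trophy",
]
choice = "is too big."

# One loglikelihood request per candidate context, mirroring the branch above.
requests = [("loglikelihood", (context, " {}".format(choice))) for context in contexts]
for request_type, arguments in requests:
    print(request_type, arguments)
```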
@@ -816,11 +847,7 @@ class ConfigurableTask(Task):
                 gold = int(self.doc_to_target(doc))
 
             # retrieve choices in List[str] form, to compute choice lengths, etc.
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             if (
                 2 * len(choices) == len(lls)
                 and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -857,6 +884,21 @@ class ConfigurableTask(Task):
                 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                 result_dict["acc_mutual_info"] = acc_mutual_info
 
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            lls, is_greedy = zip(*results)
+
+            if self._config.gold_alias is not None:
+                gold = int(self.gold_alias(doc))
+            else:
+                gold = int(self.doc_to_target(doc))
+
+            pred = np.argmax(lls)
+            acc = 1.0 if np.argmax(lls) == gold else 0.0
+
+            result_dict = {
+                **({"acc": acc} if "acc" in use_metric else {}),
+            }
+
         elif self.OUTPUT_TYPE == "greedy_until":
             if self._config.gold_alias is not None:
@@ -875,7 +917,7 @@ class ConfigurableTask(Task):
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice' or 'winograd_schema'",
             )
         return result_dict
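A small sketch of the corresponding scoring path in `process_results` for `winograd_schema`: the loglikelihoods returned for the per-context requests are compared by argmax against the gold context index. The numbers below are made up.

```python
import numpy as np

# Made-up loglikelihoods of the shared continuation under each candidate context.
lls = [-12.3, -9.8]
gold = 1  # index of the correct context, from doc_to_target (or gold_alias)

pred = int(np.argmax(lls))
acc = 1.0 if pred == gold else 0.0
print({"acc": acc})  # {'acc': 1.0}
```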
......
@@ -181,8 +181,6 @@ def evaluate(
     samples = collections.defaultdict(list)
     requests = collections.defaultdict(list)
-    # docs = {}
-
     # get lists of each type of request
     for task_name, task in task_dict.items():
         versions[task_name] = task.VERSION
@@ -215,7 +213,7 @@ def evaluate(
         # aggregate Instances by LM method requested to get output.
         reqtype = (
             "loglikelihood"
-            if task.OUTPUT_TYPE == "multiple_choice"
+            if (task.OUTPUT_TYPE == "multiple_choice" or task.OUTPUT_TYPE == "winograd_schema")
             else task.OUTPUT_TYPE
         )  # TODO: this is hacky, fix in task.py
         requests[reqtype].extend(task.instances)
@@ -274,7 +272,6 @@ def evaluate(
                 enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
             )
         )
-
         for doc_id, doc in doc_iterator:
             # subset instances to only this document id ; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
......
@@ -9,28 +9,34 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] DROP
 - [x] ~~Lambada~~
 - [x] Lambada (Cloze variants)
-- [ ] Lambada (Multilingual)
+- [x] ~~Lambada (Multilingual)~~
 - [x] Wikitext
 - [x] PiQA
 - [ ] PROST (WIP)
 - [ ] MCTACO
-- [ ] Pubmed QA (WIP)
+- [x] Pubmed QA
 - [x] SciQ
 - [ ] QASPER
 - [ ] QA4MRE (WIP)
 - [ ] TriviaQA
 - [x] AI2 ARC
+- [ ] LogiQA
+- [x] HellaSwag
+- [x] SWAG
+- [x] OpenBookQA
+- [ ] SQuADv2
+- [x] RACE
 - [ ] LogiQA (WIP)
 - [x] HellaSwag
 - [ ] SWAG (WIP)
 - [x] OpenBookQA
 - [ ] SQuADv2 (WIP)
 - [ ] RACE (WIP)
-- [ ] HeadQA
+- [ ] HeadQA (WIP)
 - [ ] MathQA
 - [ ] WebQs
 - [ ] WSC273
-- [ ] Winogrande (WIP)
+- [x] Winogrande
 - [x] ANLI
 - [ ] Hendrycks Ethics
 - [ ] TruthfulQA
@@ -38,7 +44,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] Hendrycks Math (WIP)
 - [ ] Asdiv
 - [ ] GSM8k
-- [ ] Arithmetic (WIP)
+- [x] Arithmetic
 - [ ] MMMLU
 - [ ] Translation (WMT) suite
 - [x] Unscramble
......
group:
- arithmetic
task: arithmetic_1dc
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_1dc
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_2da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_2dm
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2dm
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_2ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_3da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_3da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_3ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_3ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_4da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_4da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_4ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_4ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_5da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_5da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_5ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_5ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
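As a sanity check on the templates in the arithmetic configs above, a minimal sketch of how `doc_to_text` / `doc_to_target` render a document: the example doc is invented rather than taken from EleutherAI/arithmetic, and `jinja2` stands in for the harness's own template rendering.

```python
from jinja2 import Template

# Invented document in the shape the templates expect ("context"/"completion"
# fields); not an actual row from the dataset.
doc = {"context": "Question: What is 6 plus 3? Answer:", "completion": " 9"}

prompt = Template("{{context}}").render(**doc)     # doc_to_text
target = Template("{{completion}}").render(**doc)  # doc_to_target

# Under output_type "loglikelihood", the acc metric reflects whether the
# target is the model's greedy continuation of the prompt.
print(repr(prompt), repr(target))
```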
# LAMBADA
### Paper
The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf
LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
### Citation
@misc{
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
title={The LAMBADA dataset},
DOI={10.5281/zenodo.2630551},
publisher={Zenodo},
year={2016},
month={Aug}
}
### Subtasks
* `lambada_mt_{en, fr, de, it, es}`: Machine-translated versions of OpenAI's Lambada variant.
### Checklist
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
(This task is novel to the Evaluation Harness, and has been checked against v0.3.0 of the harness.)
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_de
dataset_name: de
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_en
dataset_path: EleutherAI/lambada_openai
dataset_name: en
output_type: loglikelihood
test_split: test
template_aliases: ""
doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
doc_to_target: "{{' '+text.split(' ')[-1]}}"
should_decontaminate: true
doc_to_decontamination_query: "{{text}}"
metric_list:
- metric: perplexity
aggregation: perplexity
higher_is_better: false
- metric: acc
aggregation: mean
higher_is_better: true
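The two Jinja expressions in this config implement last-word prediction; an equivalent plain-Python sketch (with an invented placeholder passage) is below.

```python
# Invented placeholder passage, not a LAMBADA example.
text = "The quick brown fox jumps over the lazy dog"

# "{{text.split(' ')[:-1]|join(' ')}}"  -> everything except the final token
doc_to_text = " ".join(text.split(" ")[:-1])
# "{{' '+text.split(' ')[-1]}}"         -> the final token, with a leading space
doc_to_target = " " + text.split(" ")[-1]

assert doc_to_text == "The quick brown fox jumps over the lazy"
assert doc_to_target == " dog"
```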
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_es
dataset_name: es
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_fr
dataset_name: fr
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_it
dataset_name: it