Unverified commit 287f7efc, authored by Hailey Schoelkopf and committed by GitHub

Merge branch 'big-refactor' into headqa

parents 3ab3c388 884ba785
@@ -80,6 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"winograd_schema": ["acc"],
"greedy_until": ["exact_match"],
}
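For orientation, here is a minimal sketch of how default metrics are looked up by output type. The registry below is trimmed to the entries visible in this hunk, and the lookup helper is illustrative only, not part of the harness:

```python
# Trimmed copy of the registry entries shown above; the helper is illustrative.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "winograd_schema": ["acc"],
    "greedy_until": ["exact_match"],
}

def default_metrics(output_type: str) -> list:
    """Return the default metric names registered for an output type."""
    return DEFAULT_METRIC_REGISTRY.get(output_type, [])

assert default_metrics("winograd_schema") == ["acc"]
```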
@@ -43,6 +43,7 @@ ALL_OUTPUT_TYPES = [
"multiple_choice",
"loglikelihood_rolling",
"greedy_until",
"winograd_schema"
]
@@ -73,9 +74,10 @@ class TaskConfig(dict):
fewshot_delimiter: str = "\n\n"
# runtime configuration options
num_fewshot: int = 0
batch_size: int = 1
# scoring options
metric_list: str = None
gold_alias: Union[Callable, str] = None
create_choices: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
repeats: int = 1
@@ -99,13 +101,29 @@ class TaskConfig(dict):
if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs:
assert (
self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!"
elif self.output_type == "greedy_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
if self.generation_kwargs is not None:
if self.output_type != "greedy_until":
eval_logger.warning(
"passed `generation_kwargs`, but not using a generation request type!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "greedy_until":
# otherwise, ensure that we generate greedily in the absence of explicit arguments
self.generation_kwargs = {
"until": None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter],
"do_sample": False,
"temperature": 0.0,
}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
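To summarize the defaulting above: user-supplied `generation_kwargs` get a float `temperature` and an `until` stop sequence defaulting to the few-shot delimiter, while `greedy_until` tasks with no explicit kwargs fall back to greedy decoding. A standalone sketch of that resolution (the helper name is hypothetical; the real logic lives in the TaskConfig initialization shown above, which additionally warns on non-generation output types):

```python
# Illustrative helper mirroring the defaulting logic above; not harness code.
def resolve_generation_kwargs(generation_kwargs, output_type, fewshot_delimiter="\n\n"):
    if generation_kwargs is not None:
        kwargs = dict(generation_kwargs)
        if "temperature" in kwargs:
            kwargs["temperature"] = float(kwargs["temperature"])
        kwargs.setdefault("until", [fewshot_delimiter])
        return kwargs
    if output_type == "greedy_until":
        # greedy decoding by default, stopping at the few-shot delimiter
        return {
            "until": None if fewshot_delimiter is None else [fewshot_delimiter],
            "do_sample": False,
            "temperature": 0.0,
        }
    return None

assert resolve_generation_kwargs(None, "greedy_until") == {
    "until": ["\n\n"],
    "do_sample": False,
    "temperature": 0.0,
}
```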
@@ -297,6 +315,18 @@ class Task(abc.ABC):
The processed version of the specified `doc`.
"""
return doc
def create_choices(self, doc):
if self._config.create_choices is None:
return ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
elif type(self._config.create_choices) == str:
return utils.apply_template(self._config.create_choices, doc)
else:
return self._config.create_choices(doc)
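`create_choices` covers three cases: with no config it renders the `{{answer_choices}}` alias and parses the rendered string back into a list, a string config is applied as a template to the doc, and a callable config is called on the doc directly. A rough sketch with a hypothetical doc (the template rendering is stood in for by `str()`; the harness uses `utils.apply_template`):

```python
import ast

# Hypothetical doc; field names are illustrative only.
doc = {"answer_choices": ["yes", "no"], "options": "A|B|C"}

# 1) create_choices is None: render "{{answer_choices}}" and parse the result
#    back into a Python list, e.g. the string "['yes', 'no']".
rendered = str(doc["answer_choices"])  # stand-in for the Jinja rendering
assert ast.literal_eval(rendered) == ["yes", "no"]

# 2) create_choices is a string: it is treated as a template and rendered
#    against the doc (utils.apply_template, not reproduced here).

# 3) create_choices is a callable: it is applied to the doc directly.
def choices_fn(d):
    return d["options"].split("|")

assert choices_fn(doc) == ["A", "B", "C"]
```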
@property
def instances(self):
@@ -468,7 +498,7 @@ class Task(abc.ABC):
The fewshot context.
"""
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
# (batch size, num_fewshot)
# (num_fewshot)
return self._config.to_dict()
@@ -728,11 +758,8 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
# we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
# TODO: any cleaner way to do this?
choices = ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
choices = self.create_choices(doc)
request_list = [
Instance(
request_type="loglikelihood",
@@ -768,6 +795,26 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self._config.generation_kwargs)
elif self.OUTPUT_TYPE == "winograd_schema":
# similar to the multiple_choice task type, except that each request pairs
# a different candidate context with the same continuation
contexts = self.create_choices(doc)
choice = self.doc_to_target(doc)
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(context, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, context in enumerate(contexts)
]
return request_list
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
)
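In contrast to `multiple_choice` (one context scored against several continuations), the `winograd_schema` branch above builds one loglikelihood request per candidate context, all sharing the same continuation. A toy example of the argument tuples it produces (the doc contents are made up):

```python
# Hypothetical Winograd-style doc: two candidate contexts, one shared continuation.
contexts = [
    "The trophy doesn't fit in the suitcase because the trophy is too",
    "The trophy doesn't fit in the suitcase because the suitcase is too",
]
choice = "small"

# Mirrors the (context, " {}".format(choice)) arguments built above.
arguments = [(context, " {}".format(choice)) for context in contexts]
assert arguments[1] == (
    "The trophy doesn't fit in the suitcase because the suitcase is too",
    " small",
)
```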
@@ -816,11 +863,7 @@ class ConfigurableTask(Task):
gold = int(self.doc_to_target(doc))
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
choices = self.create_choices(doc)
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -857,6 +900,21 @@ class ConfigurableTask(Task):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "winograd_schema":
lls, is_greedy = zip(*results)
if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
acc = 1.0 if pred == gold else 0.0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None:
@@ -875,7 +933,7 @@ class ConfigurableTask(Task):
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice' or 'winograd_schema' ",
)
return result_dict
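For the `winograd_schema` branch above, the prediction is simply the index of the context whose shared continuation gets the highest log-likelihood, compared against the gold index. A toy numeric example (the log-likelihood values are made up):

```python
import numpy as np

# Made-up log-likelihoods for two (context, continuation) requests.
lls = [-12.7, -9.3]
gold = 1  # index of the correct context, e.g. from doc_to_target(doc)

pred = int(np.argmax(lls))  # 1: the second context scores higher
acc = 1.0 if pred == gold else 0.0
assert acc == 1.0
```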
@@ -181,8 +181,6 @@ def evaluate(
samples = collections.defaultdict(list)
requests = collections.defaultdict(list)
# docs = {}
# get lists of each type of request
for task_name, task in task_dict.items():
versions[task_name] = task.VERSION
@@ -202,7 +200,7 @@
# aggregate Instances by LM method requested to get output.
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
if (task.OUTPUT_TYPE == "multiple_choice" or task.OUTPUT_TYPE == "winograd_schema")
else task.OUTPUT_TYPE
) # TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
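Both `multiple_choice` and `winograd_schema` tasks ultimately issue loglikelihood requests, so the evaluator buckets them under `loglikelihood` when dispatching to the LM. A small sketch of that mapping (the helper is illustrative only):

```python
# Illustrative request-type bucketing, mirroring the conditional above.
def request_type_for(output_type: str) -> str:
    if output_type in ("multiple_choice", "winograd_schema"):
        return "loglikelihood"
    return output_type

assert request_type_for("winograd_schema") == "loglikelihood"
assert request_type_for("greedy_until") == "greedy_until"
```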
@@ -261,7 +259,6 @@
enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
)
)
for doc_id, doc in doc_iterator:
# subset instances to only this document id; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
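The per-document loop then keeps only the instances belonging to the current `doc_id` and orders them by `idx`, so results line up with the requests that produced them. A self-contained sketch (using a stand-in namedtuple rather than the harness `Instance` class):

```python
from collections import namedtuple

# Stand-in for the harness Instance class, illustrative only.
Instance = namedtuple("Instance", ["doc_id", "idx"])
instances = [Instance(1, 1), Instance(0, 1), Instance(0, 0)]

# Keep only doc 0's instances, ordered by idx.
per_doc = sorted((x for x in instances if x.doc_id == 0), key=lambda x: x.idx)
assert [x.idx for x in per_doc] == [0, 1]
```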
@@ -9,42 +9,47 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] DROP
- [x] ~~Lambada~~
- [x] Lambada (Cloze variants)
- [x] ~~Lambada (Multilingual)~~
- [x] Wikitext
- [x] PiQA
- [x] PROST
- [ ] MCTACO
- [x] Pubmed QA
- [x] SciQ
- [ ] QASPER
- [x] QA4MRE
- [ ] TriviaQA
- [x] AI2 ARC
- [ ] LogiQA (WIP)
- [x] HellaSwag
- [x] SWAG
- [x] OpenBookQA
- [ ] SQuADv2 (WIP)
- [x] RACE
- [x] HeadQA (WIP)
- [ ] MathQA (WIP)
- [ ] WebQs
- [ ] WSC273
- [x] Winogrande
- [x] ANLI
- [ ] Hendrycks Ethics
- [ ] TruthfulQA
- [ ] MuTual
- [ ] Hendrycks Math (WIP)
- [ ] Asdiv (WIP)
- [ ] GSM8k
- [x] Arithmetic
- [ ] MMMLU
- [ ] Translation (WMT) suite
- [ ] Unscramble (WIP)
- [x] ~~Pile (perplexity)~~
- [ ] BLiMP
- [ ] ToxiGen (WIP)
- [ ] StoryCloze
- [ ] NaturalQs
- [ ] CrowS-Pairs
group:
- arithmetic
task: arithmetic_1dc
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_1dc
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_2da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_2dm
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2dm
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_2ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_3da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_3da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_3ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_3ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_4da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_4da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_4ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_4ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_5da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_5da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
group:
- arithmetic
task: arithmetic_5ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_5ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
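Each arithmetic config above renders a doc's `context` field as the prompt and its `completion` field as the target of a single loglikelihood request. A rough sketch of the resulting request arguments (the doc contents below are hypothetical, not taken from the dataset):

```python
# Hypothetical EleutherAI/arithmetic-style doc; the exact strings are illustrative.
doc = {"context": "Question: What is 17 plus 25?\nAnswer:", "completion": " 42"}

# With doc_to_text: "{{context}}" and doc_to_target: "{{completion}}",
# the loglikelihood request arguments are simply (context, completion).
request_args = (doc["context"], doc["completion"])
assert request_args == ("Question: What is 17 plus 25?\nAnswer:", " 42")
```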
# LAMBADA
### Paper
The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf
LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
### Citation
@misc{
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
title={The LAMBADA dataset},
DOI={10.5281/zenodo.2630551},
publisher={Zenodo},
year={2016},
month={Aug}
}
### Subtasks
* `lambada_mt_{en, fr, de, it, es}`: Machine-translated versions of OpenAI's Lambada variant.
### Checklist
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
(This task is novel to the Evaluation Harness, and has been checked against v0.3.0 of the harness.)
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_de
dataset_name: de
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_en
dataset_path: EleutherAI/lambada_openai
dataset_name: en
output_type: loglikelihood
test_split: test
template_aliases: ""
doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
doc_to_target: "{{' '+text.split(' ')[-1]}}"
should_decontaminate: true
doc_to_decontamination_query: "{{text}}"
metric_list:
- metric: perplexity
aggregation: perplexity
higher_is_better: false
- metric: acc
aggregation: mean
higher_is_better: true
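The two Jinja templates above split the passage on spaces, using everything but the last token as the prompt and the space-prefixed last token as the target. Equivalent Python, on a made-up passage:

```python
# Equivalent Python for the Jinja templates above; the passage is made up.
text = "He opened the door and stepped out into the night"

doc_to_text = " ".join(text.split(" ")[:-1])  # {{text.split(' ')[:-1]|join(' ')}}
doc_to_target = " " + text.split(" ")[-1]     # {{' '+text.split(' ')[-1]}}

assert doc_to_text == "He opened the door and stepped out into the"
assert doc_to_target == " night"
```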
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_es
dataset_name: es
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_fr
dataset_name: fr
include: lambada_mt_en.yaml
group:
- lambada_multilingual
- loglikelihood
- perplexity
task: lambada_openai_mt_it
dataset_name: it
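Each per-language config above includes `lambada_mt_en.yaml` and overrides only the task name and `dataset_name`. Roughly, the effective config is the base config with those keys replaced, sketched below as a shallow dict merge (an approximation, not the harness's actual include machinery):

```python
# Approximate effect of the include/override pattern above (shallow merge sketch).
base = {
    "task": "lambada_openai_mt_en",
    "dataset_path": "EleutherAI/lambada_openai",
    "dataset_name": "en",
    "output_type": "loglikelihood",
}
override = {"task": "lambada_openai_mt_de", "dataset_name": "de"}

merged = {**base, **override}
assert merged["dataset_name"] == "de"
assert merged["dataset_path"] == "EleutherAI/lambada_openai"
```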