Unverified commit 45737a38, authored by Lintang Sutawika, committed by GitHub

Merge pull request #627 from fattorib/refactor-more-tasks

[Refactor] Add: SWAG, RACE, Arithmetic, Winogrande, PubmedQA
parents 0ba4ae15 02362e6a
@@ -80,6 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
+    "winograd_schema": ["acc"],
     "greedy_until": ["exact_match"],
 }
@@ -43,6 +43,7 @@ ALL_OUTPUT_TYPES = [
     "multiple_choice",
     "loglikelihood_rolling",
     "greedy_until",
+    "winograd_schema",
 ]
@@ -75,6 +76,8 @@ class TaskConfig(dict):
     num_fewshot: int = 0
     # scoring options
     metric_list: str = None
+    gold_alias: Union[Callable, str] = None
+    create_choices: Union[Callable, str] = None
     output_type: str = "greedy_until"
     generation_kwargs: dict = None
     repeats: int = 1
@@ -297,6 +300,18 @@ class Task(abc.ABC):
         """
         return doc
 
+    def create_choices(self, doc):
+        if self._config.create_choices is None:
+            return ast.literal_eval(
+                utils.apply_template(
+                    self._config.template_aliases + "{{answer_choices}}", doc
+                )
+            )
+        elif type(self._config.create_choices) == str:
+            return utils.apply_template(self._config.create_choices, doc)
+        else:
+            return self._config.create_choices(doc)
+
     @property
     def instances(self):
         """After calling `task.build_all_requests()`, tasks
@@ -727,11 +742,8 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
             # TODO: any cleaner way to do this?
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             request_list = [
                 Instance(
                     request_type="loglikelihood",
@@ -767,6 +779,26 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "greedy_until":
             arguments = (ctx, self._config.generation_kwargs)
 
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            # similar to multiple_choice task type except each request contains
+            # multiple differing contexts with the same continuation
+            contexts = self.create_choices(doc)
+            choice = self.doc_to_target(doc)
+
+            request_list = [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(context, " {}".format(choice)),
+                    idx=i,
+                    **kwargs,
+                )
+                for i, context in enumerate(contexts)
+            ]
+            return request_list
+
         return Instance(
             request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
         )
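To make the new request shape concrete: for a Winogrande-style document, `create_choices` returns the partial contexts and `doc_to_target` the shared continuation, so each doc expands into one loglikelihood request per context. A sketch with hand-written values (the preprocessing that actually produces these strings appears later in this PR):

# Hand-written example of what the winograd_schema branch builds per document.
contexts = [
    "The city councilmen refused the demonstrators a permit because the city councilmen",
    "The city councilmen refused the demonstrators a permit because the demonstrators",
]
choice = "feared violence."

# One loglikelihood request per context, all sharing the same continuation.
arguments = [(context, " {}".format(choice)) for context in contexts]
# [('... because the city councilmen', ' feared violence.'),
#  ('... because the demonstrators', ' feared violence.')]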
@@ -815,11 +847,7 @@ class ConfigurableTask(Task):
             gold = int(self.doc_to_target(doc))
             # retrieve choices in List[str] form, to compute choice lengths, etc.
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             if (
                 2 * len(choices) == len(lls)
                 and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -856,6 +884,21 @@ class ConfigurableTask(Task):
                 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                 result_dict["acc_mutual_info"] = acc_mutual_info
 
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            lls, is_greedy = zip(*results)
+
+            if self._config.gold_alias is not None:
+                gold = int(self.gold_alias(doc))
+            else:
+                gold = int(self.doc_to_target(doc))
+
+            pred = np.argmax(lls)
+            acc = 1.0 if np.argmax(lls) == gold else 0.0
+
+            result_dict = {
+                **({"acc": acc} if "acc" in use_metric else {}),
+            }
+
         elif self.OUTPUT_TYPE == "greedy_until":
             if self._config.gold_alias is not None:
@@ -874,7 +917,7 @@ class ConfigurableTask(Task):
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice' or 'winograd_schema'",
             )
         return result_dict
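Scoring for the new output type then reduces to an argmax over the per-context log-likelihoods compared against the gold index. A minimal sketch with made-up numbers:

import numpy as np

# Made-up log-likelihoods, one per context submitted for this document.
lls = np.array([-12.3, -14.1])
gold = 0  # index returned by gold_alias (or doc_to_target) for this document

acc = 1.0 if np.argmax(lls) == gold else 0.0
print({"acc": acc})  # {'acc': 1.0}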
@@ -200,7 +200,7 @@ def evaluate(
         # aggregate Instances by LM method requested to get output.
         reqtype = (
             "loglikelihood"
-            if task.OUTPUT_TYPE == "multiple_choice"
+            if (task.OUTPUT_TYPE == "multiple_choice" or task.OUTPUT_TYPE == "winograd_schema")
             else task.OUTPUT_TYPE
         )  # TODO: this is hacky, fix in task.py
         requests[reqtype].extend(task.instances)
@@ -259,7 +259,6 @@ def evaluate(
             enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
         )
     )
-
     for doc_id, doc in doc_iterator:
         # subset instances to only this document id ; sort by idx
         requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
@@ -14,12 +14,18 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] PiQA
 - [ ] PROST (WIP)
 - [ ] MCTACO
-- [ ] Pubmed QA (WIP)
+- [x] Pubmed QA
 - [x] SciQ
 - [ ] QASPER
 - [ ] QA4MRE (WIP)
 - [ ] TriviaQA
 - [x] AI2 ARC
+- [ ] LogiQA
+- [x] HellaSwag
+- [x] SWAG
+- [x] OpenBookQA
+- [ ] SQuADv2
+- [x] RACE
 - [ ] LogiQA (WIP)
 - [x] HellaSwag
 - [ ] SWAG (WIP)
@@ -30,7 +36,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] MathQA
 - [ ] WebQs
 - [ ] WSC273
-- [ ] Winogrande (WIP)
+- [x] Winogrande
 - [x] ANLI
 - [ ] Hendrycks Ethics
 - [ ] TruthfulQA
@@ -38,7 +44,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] Hendrycks Math (WIP)
 - [ ] Asdiv
 - [ ] GSM8k
-- [ ] Arithmetic (WIP)
+- [x] Arithmetic
 - [ ] MMMLU
 - [ ] Translation (WMT) suite
 - [ ] Unscramble
group:
  - arithmetic
task: arithmetic_1dc
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_1dc
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_2da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_2dm
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2dm
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_2ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_2ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_3da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_3da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_3ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_3ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_4da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_4da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_4ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_4ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_5da
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_5da
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true

group:
  - arithmetic
task: arithmetic_5ds
dataset_path: EleutherAI/arithmetic
dataset_name: arithmetic_5ds
output_type: loglikelihood
validation_split: validation
test_split: null
template_aliases: ""
doc_to_text: "{{context}}"
doc_to_target: "{{completion}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
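These configs only template two fields, so the prompt and target a model sees can be previewed by rendering the Jinja templates against a record. A sketch using plain jinja2 and an invented record (the harness renders templates through its own utilities, which may differ in detail):

import jinja2

# Hypothetical EleutherAI/arithmetic-style record; the values are invented.
doc = {"context": "Question: What is 95 times 45?\nAnswer:", "completion": " 4275"}

env = jinja2.Environment()
prompt = env.from_string("{{context}}").render(**doc)
target = env.from_string("{{completion}}").render(**doc)

print(repr(prompt))  # 'Question: What is 95 times 45?\nAnswer:'
print(repr(target))  # ' 4275'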
def doc_to_text(doc):
    ctxs = "\n".join(doc["context"]["contexts"])
    return "Abstract: {}\nQuestion: {}\nAnswer:".format(ctxs, doc["question"])


def doc_to_target(doc):
    return " {}".format(doc["final_decision"])


def gold_alias(doc):
    # Map the textual label onto its index in the answer_choices alias
    # (['yes', 'no', 'maybe'] in the accompanying YAML config).
    dict_to_label = {"yes": 0, "no": 1, "maybe": 2}
    return dict_to_label[doc["final_decision"]]
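For reference, calling the helpers above on a PubMedQA-shaped record (field names follow the pqa_labeled schema; the values are invented) gives:

# Assumes doc_to_text, doc_to_target and gold_alias from preprocess_pubmedqa are in scope.
doc = {
    "context": {"contexts": ["Background sentence one.", "Background sentence two."]},
    "question": "Does the intervention improve outcomes?",
    "final_decision": "yes",
}

print(doc_to_text(doc))
# Abstract: Background sentence one.
# Background sentence two.
# Question: Does the intervention improve outcomes?
# Answer:
print(repr(doc_to_target(doc)))  # ' yes'
print(gold_alias(doc))           # 0 -> index of 'yes' in the answer_choices alias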
group:
  - multiple_choice
task: pubmed_qa
dataset_path: pubmed_qa
dataset_name: pqa_labeled
output_type: multiple_choice
training_split: null
validation_split: null
test_split: train
template_aliases: "{% set answer_choices = ['yes', 'no', 'maybe'] %}{% set gold = final_decision %}"
doc_to_text: !function preprocess_pubmedqa.doc_to_text
doc_to_target: !function preprocess_pubmedqa.doc_to_target
gold_alias: !function preprocess_pubmedqa.gold_alias
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
import ast


def process_ast(string):
    return ast.literal_eval(string)


def last_problem(doc):
    return process_ast(doc["problems"])[-1]


def get_answer_option(problem):
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
    answer = letter_to_num[problem["answer"]]
    return problem["options"][answer]


def create_choices(doc):
    problem = last_problem(doc)
    choices = [problem["options"][i] for i in range(4)]
    return choices


def doc_to_text(doc):
    text = "Article: " + doc["article"] + "\n\n"
    for problem in process_ast(doc["problems"])[:-1]:
        if problem["question"][-6:] == "  _  .":
            text += (
                problem["question"][-5:] + get_answer_option(problem) + "\n"
            )
        else:
            question = "Question: " + problem["question"] + "\n"
            answer = "Answer: " + get_answer_option(problem) + "\n"
            text += question + answer
    text += last_problem(doc)["question"]
    return text


def doc_to_target(doc):
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
    answer = letter_to_num[last_problem(doc)["answer"]]
    return answer
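As a quick illustration of the prompt this builds (earlier problems in the article become in-context question/answer pairs, the final problem is left unanswered), here is an invented RACE-shaped document run through the helpers above; note that `problems` is a stringified list, hence the `ast.literal_eval` round-trip:

# Assumes the preprocess_race helpers above are in scope; the document is invented.
doc = {
    "article": "Tom went to the market to buy apples.",
    "problems": str(
        [
            {"question": "Where did Tom go?", "options": ["home", "the market", "school", "work"], "answer": "B"},
            {"question": "What did Tom buy?", "options": ["apples", "pears", "milk", "bread"], "answer": "A"},
        ]
    ),
}

print(doc_to_text(doc))
# Article: Tom went to the market to buy apples.
#
# Question: Where did Tom go?
# Answer: the market
# What did Tom buy?
print(create_choices(doc))  # ['apples', 'pears', 'milk', 'bread']
print(doc_to_target(doc))   # 0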
group:
  - multiple_choice
task: race
dataset_path: bfattori/race
dataset_name: high
output_type: multiple_choice
test_split: test
create_choices: !function preprocess_race.create_choices
doc_to_text: !function preprocess_race.doc_to_text
doc_to_target: !function preprocess_race.doc_to_target
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
group:
  - multiple_choice
task: swag
dataset_path: swag
dataset_name: regular
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: null
template_aliases: "{% set answer_choices = [ending0, ending1, ending2, ending3] %}{% set gold = label %}"
doc_to_text: "{{startphrase}}"
doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
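The SWAG config does all of its work through template_aliases; rendering the templates against a row shows how the gold ending is selected. A sketch with plain jinja2 and an invented row (column names follow the HF swag "regular" config):

import jinja2

# Invented swag/regular-style row for illustration.
doc = {
    "startphrase": "A woman is outside with a bucket and a dog.",
    "ending0": "The dog runs away.",
    "ending1": "She washes the dog.",
    "ending2": "The bucket is empty.",
    "ending3": "She sings a song.",
    "label": 1,
}

aliases = (
    "{% set answer_choices = [ending0, ending1, ending2, ending3] %}"
    "{% set gold = label %}"
)
env = jinja2.Environment()
print(env.from_string(aliases + "{{startphrase}}").render(**doc))
# A woman is outside with a bucket and a dog.
print(env.from_string(aliases + "{{answer_choices[gold]}}").render(**doc))
# She washes the dog.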
def partial_context(doc, option):
    # Substitute the pronoun in the sentence with the specified option
    # and ignore everything after.
    pronoun_loc = doc["sentence"].index("_")
    return doc["sentence"][:pronoun_loc] + option


def partial_target(doc):
    # The target is everything after the document specified pronoun.
    pronoun_loc = doc["sentence"].index("_") + 1
    return doc["sentence"][pronoun_loc:].strip()


def create_choices(doc):
    choices = []
    for option in [doc["option1"], doc["option2"]]:
        partial_ctx = partial_context(doc, option)
        choices.append(partial_ctx)
    return choices


def gold_alias(doc):
    answer_to_num = {"1": 0, "2": 1}
    return answer_to_num[doc["answer"]]
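Tying the helpers above back to the winograd_schema request construction earlier in this PR: `create_choices` yields the two partial contexts and `partial_target` the shared continuation. An invented Winogrande-shaped row run through them:

# Assumes the preprocess functions above are in scope; the row is invented.
doc = {
    "sentence": "The trophy didn't fit in the suitcase because _ was too big.",
    "option1": "the trophy",
    "option2": "the suitcase",
    "answer": "1",
}

print(create_choices(doc))
# ["The trophy didn't fit in the suitcase because the trophy",
#  "The trophy didn't fit in the suitcase because the suitcase"]
print(repr(partial_target(doc)))  # 'was too big.'
print(gold_alias(doc))            # 0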