Unverified Commit 30aa9c33 authored by Lintang Sutawika, committed by GitHub

Merge pull request #758 from EleutherAI/add_triviaqa

[Refactor] Add triviaqa
parents 0629a8bf 4b0ab122
@@ -48,7 +48,9 @@ class Sampler:
                 )
                 + self.target_delimiter
                 + (
-                    self.doc_to_target(doc)
+                    self.doc_to_target(doc)[0]
+                    if type(self.doc_to_target(doc)) is list
+                    else self.doc_to_target(doc)
                     if (
                         self.config.doc_to_choice is None
                         or type(self.doc_to_target(doc)) is str
...
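The sampler change above exists because tasks like TriviaQA now return a list of gold aliases from `doc_to_target`; when building few-shot examples, only the first alias is rendered as the in-context answer. A minimal sketch of that behaviour (not part of the diff; `render_target` and its arguments are made up for illustration):

```python
# Hypothetical helper mirroring the conditional expression added above:
# list-valued targets (e.g. TriviaQA answer aliases) contribute only their
# first element to the few-shot example string.
def render_target(target, target_delimiter=" "):
    if type(target) is list:
        target = target[0]
    return target_delimiter + str(target)

print(repr(render_target(["Sinclair Lewis", "Harry Sinclair Lewis"])))  # ' Sinclair Lewis'
print(repr(render_target("Paris")))                                     # ' Paris'
```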
@@ -771,7 +771,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: dict) -> Union[int, str]:
+    def doc_to_target(self, doc: dict) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
@@ -790,8 +790,12 @@ class ConfigurableTask(Task):
                 target_string = utils.apply_template(doc_to_target, doc)
                 if target_string.isdigit():
                     return ast.literal_eval(target_string)
+                elif (target_string[0] == "[") and (target_string[-1] == "]"):
+                    return ast.literal_eval(target_string)
                 else:
                     return target_string
+        elif type(doc_to_target) == list:
+            return doc_to_target
         elif callable(doc_to_target):
             return doc_to_target(doc)
         # Used when applying a Promptsource template
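`doc_to_target` can now hand back a list of gold answers. When the target template renders to a bracketed, Python-style list string (as TriviaQA's `"{{answer.aliases}}"` does), the new branch parses it back into a list with `ast.literal_eval`. A standalone sketch of that parsing path (illustrative only, not the harness code itself):

```python
import ast

# A Jinja template like "{{answer.aliases}}" renders a list field as a
# Python-style string; the bracket check above recovers the list.
target_string = "['Sinclair Lewis', 'Harry Sinclair Lewis']"

if target_string.isdigit():
    target = ast.literal_eval(target_string)   # integer class label
elif target_string[0] == "[" and target_string[-1] == "]":
    target = ast.literal_eval(target_string)   # list of gold aliases
else:
    target = target_string                     # plain string target

print(target)  # ['Sinclair Lewis', 'Harry Sinclair Lewis']
```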
@@ -1019,20 +1023,20 @@ class ConfigurableTask(Task):
                         res = res[key]
                     scores.append(res)
                 if any(scores):
-                    result = 1.0
+                    result_score = 1.0
                 else:
-                    result = 0.0
+                    result_score = 0.0
             else:
-                result = self._metric_fn_list[key](
+                result_score = self._metric_fn_list[key](
                     references=[gold],
                     predictions=[result],
                     **self._metric_fn_kwargs[key],
                 )
-            if isinstance(result, dict):
-                result_dict.update(result)
+            if isinstance(result_score, dict):
+                result_dict.update(result_score)
             else:
-                result_dict[key] = result
+                result_dict[key] = result_score
             else:
                 raise ValueError(
                     f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
...
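The rename to `result_score` keeps the per-metric score separate from the model's `result` string, which is still needed as the prediction when the gold is a list: each alias is scored individually and the example counts as correct if any alias matches. A self-contained sketch of that `any(scores)` logic, with a stub `exact_match` standing in for the registered metric function:

```python
# Stub metric with the same calling convention as the registered metric fns.
def exact_match(references, predictions):
    return float(references[0] == predictions[0])

def score_against_aliases(gold_aliases, prediction):
    # Score the prediction against every gold alias; any hit counts as 1.0.
    scores = [exact_match(references=[gold], predictions=[prediction]) for gold in gold_aliases]
    return 1.0 if any(scores) else 0.0

print(score_against_aliases(["Sinclair Lewis", "Harry Sinclair Lewis"], "Sinclair Lewis"))  # 1.0
print(score_against_aliases(["Sinclair Lewis"], "Upton Sinclair"))                          # 0.0
```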
@@ -8,6 +8,7 @@ FILTER_REGISTRY = {
     "regex": extraction.RegexFilter,
     "majority_vote": selection.MajorityVoteFilter,
     "take_first_k": selection.TakeKFilter,
+    "remove_whitespace": extraction.WhitespaceFilter,
     # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
     # that takes an input and returns a scalar and then should select the max reward,
     # or should implement different filters for different ways of handling a reward model's inference.
...
@@ -36,3 +36,26 @@ class RegexFilter(Filter):
         # print(filtered_resps)

         return filtered_resps
+
+
+class WhitespaceFilter(Filter):
+    """ """
+
+    def __init__(self):
+        pass
+
+    def apply(self, resps):
+        def filter_set(inst):
+            filtered_resp = []
+            for resp in inst:
+                if resp.startswith(" "):
+                    resp = resp[1:]
+                filtered_resp.append(resp)
+            return filtered_resp
+
+        filtered_resps = [filter_set(resp) for resp in resps]
+
+        return filtered_resps
...
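For reference, this is what the new filter does to a batch of generations: `resps` holds one list of responses per document, and a single leading space is stripped from each response. A usage sketch; the import path is an assumption based on the `extraction.WhitespaceFilter` registry entry above:

```python
# Assumed module path (the class is added next to RegexFilter in the
# harness's extraction filters).
from lm_eval.filters.extraction import WhitespaceFilter

# One list of generations per document, as the filter pipeline passes them.
resps = [[" Sinclair Lewis", "Sinclair Lewis"], [" Paris"]]

filtered = WhitespaceFilter().apply(resps)
print(filtered)  # [['Sinclair Lewis', 'Sinclair Lewis'], ['Paris']]
```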
@@ -18,7 +18,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] SciQ
 - [ ] QASPER
 - [x] QA4MRE
-- [ ] TriviaQA (Lintang)
+- [x] TriviaQA
 - [x] AI2 ARC
 - [x] LogiQA
 - [x] HellaSwag
...
# TriviaQA
### Paper
Title: `TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension`
Abstract: https://arxiv.org/abs/1705.03551
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
### Citation
```
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
```
### Subtasks
Tasks defined in this folder, and their names:
* `triviaqa`: `Generate an answer based on the question.`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
task: triviaqa
dataset_path: trivia_qa
dataset_name: rc.nocontext
output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: "Question: {{question}}?\nAnswer:"
doc_to_target: "{{answer.aliases}}"
should_decontaminate: true
doc_to_decontamination_query: question
generation_kwargs:
  until:
    - "\n"
    - "."
    - ","
  do_sample: false
  temperature: 0.0
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
      - function: take_first
target_delimiter: " "
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
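Putting the config together, here is a sketch of what it implies for one made-up document: the prompt template, the list-valued target, truncation at the stop sequences, the whitespace filter, and the final exact-match check. String handling is simplified; the harness's own Jinja rendering and the `ignore_case`/`ignore_punctuation` options of `exact_match` are not reproduced here.

```python
# Illustrative end-to-end walk-through of the YAML above on a made-up doc.
doc = {
    "question": "Which American-born Sinclair won the Nobel Prize for Literature in 1930",
    "answer": {"aliases": ["Sinclair Lewis", "Harry Sinclair Lewis"]},
}

prompt = f"Question: {doc['question']}?\nAnswer:"   # doc_to_text
gold_aliases = doc["answer"]["aliases"]             # doc_to_target -> list of aliases

generation = " Sinclair Lewis, an American writer"  # hypothetical model output
for stop in ["\n", ".", ","]:                       # generation_kwargs.until
    generation = generation.split(stop)[0]
if generation.startswith(" "):                      # remove_whitespace + take_first
    generation = generation[1:]

print(prompt)
print(generation in gold_aliases)                   # exact_match against any alias -> True
```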