Commit 79545adb authored by Benjamin Fattori's avatar Benjamin Fattori

Merge remote-tracking branch 'upstream/big-refactor' into seq2seq-refactor

parents eb7b9095 761f0087
@@ -7,7 +7,6 @@ output_type: multiple_choice
 training_split: train
 validation_split: validation
 test_split: test
-# TODO: we should see how shuffling answer choices affects perf.
 template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
 doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
 doc_to_target: "{{gold}}" # this will be cast to an int.
...
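For illustration only (not part of this commit): the template_aliases field in the hunk above is a Jinja2 fragment that binds answer_choices and the gold label index before doc_to_text / doc_to_target are rendered. A minimal sketch of the idea, using the jinja2 package directly and a made-up example doc rather than the harness's own rendering code:

import jinja2

# Hypothetical example doc with the field names referenced by the templates above.
doc = {
    "support": "The mitochondria is the powerhouse of the cell.",
    "question": "What is the powerhouse of the cell?",
    "distractor1": "ribosome",
    "distractor2": "nucleus",
    "distractor3": "chloroplast",
    "correct_answer": "mitochondria",
}
template_aliases = (
    "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}"
    "{% set gold = 3 %}"
)
# Prepending the aliases puts answer_choices and gold in scope for the target template.
target = jinja2.Template(template_aliases + "{{gold}}").render(**doc)
print(int(target))  # 3, the index of correct_answer within answer_choices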
group:
- super-glue-lm-eval-v1
task: "default"
dataset_path: super_glue
dataset_name: boolq
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
import sklearn.metrics
import numpy as np
def cb_multi_fi(items):
    preds, golds = zip(*items)
    preds = np.array(preds)
    golds = np.array(golds)
    # One-vs-rest F1 for each of the three CB labels (True / False / Neither).
    f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0)
    f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1)
    f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
    # Macro-average the three per-class F1 scores.
    avg_f1 = np.mean([f11, f12, f13])
    return avg_f1
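For context, a toy invocation of cb_multi_fi with made-up label values. The cb config below wires this function in through the !function reference in its metric_list, passing a list of (prediction, gold) label pairs:

toy_items = [(0, 0), (1, 1), (2, 1), (0, 0), (2, 2), (1, 0)]  # made-up (pred, gold) pairs
print(cb_multi_fi(toy_items))  # macro-average of the three one-vs-rest F1 scores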
group:
- super-glue-lm-eval-v1
task: "default"
dataset_path: super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
- metric: acc
- metric: f1
aggregation: !function "aggregate.cb_multi_fi"
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
# import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
@register_task("triviaqa")
class TriviaQA(Task):
VERSION = 1
DATASET_PATH = "trivia_qa" # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = "unfiltered.nocontext"
OUTPUT_TYPE = "greedy_until"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
def doc_to_text(self, doc):
return f"Q: {doc['question']}\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
        # Optimization: remove any alias that has another alias as a strict prefix.
        # If the shorter alias is already an accepted answer, the longer one never needs checking.
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(
ctx,
{
"until": ["\n", ".", ","],
"do_sample": False,
},
),
idx=0,
**kwargs,
)
return continuation
def process_results(self, doc, results):
continuation = (
results[0]
.strip()
.lower()
.translate(str.maketrans("", "", string.punctuation))
)
list_of_candidates = [
alias.lower().translate(str.maketrans("", "", string.punctuation))
for alias in self._remove_prefixes(doc["answer"]["aliases"])
]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self):
return {
"em": mean,
}
def higher_is_better(self):
return {"em": True}
@@ -13,7 +13,7 @@ import re
 from lm_eval.api.task import PerplexityTask
-from lm_eval.api.register import register_task, register_group
+from lm_eval.api.registry import register_task, register_group
 _CITATION = """
 @misc{merity2016pointer,
...