Commit 79545adb authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

Merge remote-tracking branch 'upstream/big-refactor' into seq2seq-refactor

parents eb7b9095 761f0087
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -19,4 +19,4 @@ metric_list:
higher_is_better: false
- metric: bits_per_byte
aggregation: bits_per_byte
higher_is_better: false
\ No newline at end of file
higher_is_better: false
group:
group:
- piqa_yaml_grp
task: piqa_yaml
dataset_path: piqa
......@@ -9,7 +9,7 @@ validation_split: validation
test_split: null
template_aliases: "{% set question = goal %}{% set answer_choices = [sol1, sol2] %}{% set gold = label %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
......
group:
group:
- sciq_yaml_grp
task: sciq_yaml
dataset_path: sciq
......@@ -7,10 +7,9 @@ output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
# TODO: we should see how shuffling answer choices affects perf.
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
......@@ -20,4 +19,4 @@ metric_list:
higher_is_better: true
- metric: acc_mutual_info
aggregation: mean
higher_is_better: true
\ No newline at end of file
higher_is_better: true
group:
- super-glue-lm-eval-v1
task: "default"
dataset_path: super_glue
dataset_name: boolq
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
import sklearn
import numpy as np
def cb_multi_fi(items):
preds, golds = zip(*items)
preds = np.array(preds)
golds = np.array(golds)
f11 = sklearn.metrics.f1_score(y_true=golds == 0, y_pred=preds == 0)
f12 = sklearn.metrics.f1_score(y_true=golds == 1, y_pred=preds == 1)
f13 = sklearn.metrics.f1_score(y_true=golds == 2, y_pred=preds == 2)
avg_f1 = np.mean([f11, f12, f13])
return avg_f1
group:
- super-glue-lm-eval-v1
task: "default"
dataset_path: super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list:
- metric: acc
- metric: f1
aggregation: !function "aggregate.cb_multi_fi"
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
# import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
@register_task("triviaqa")
class TriviaQA(Task):
VERSION = 1
DATASET_PATH = "trivia_qa" # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = "unfiltered.nocontext"
OUTPUT_TYPE = "greedy_until"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
def doc_to_text(self, doc):
return f"Q: {doc['question']}\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(
ctx,
{
"until": ["\n", ".", ","],
"do_sample": False,
},
),
idx=0,
**kwargs,
)
return continuation
def process_results(self, doc, results):
continuation = (
results[0]
.strip()
.lower()
.translate(str.maketrans("", "", string.punctuation))
)
list_of_candidates = [
alias.lower().translate(str.maketrans("", "", string.punctuation))
for alias in self._remove_prefixes(doc["answer"]["aliases"])
]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self):
return {
"em": mean,
}
def higher_is_better(self):
return {"em": True}
......@@ -13,7 +13,7 @@ import re
from lm_eval.api.task import PerplexityTask
from lm_eval.api.register import register_task, register_group
from lm_eval.api.registry import register_task, register_group
_CITATION = """
@misc{merity2016pointer,
......
......@@ -33,4 +33,4 @@ Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-
- [x] Is in Eval-harness v1.0 ?
- [x] Has been checked for regression from v1.0?
- [ ] Has been checked for equivalence with original paper methodology?
- [ ] "Main" checked variant clearly denoted?
\ No newline at end of file
- [ ] "Main" checked variant clearly denoted?
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment