Commit 1f8a8c1d authored by jon-tow's avatar jon-tow
Browse files

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

parents b4c0275d b0acb337
...@@ -51,17 +51,34 @@ class QuAC(Task): ...@@ -51,17 +51,34 @@ class QuAC(Task):
raise NotImplementedError("QuAC has no test docs.") raise NotImplementedError("QuAC has no test docs.")
def _process_doc(self, doc): def _process_doc(self, doc):
doc["title"] = doc['title'] + ' - ' + doc['section_title'] doc["title"] = doc["title"] + " - " + doc["section_title"]
return doc return doc
def doc_to_text(self, doc): def doc_to_text(self, doc):
return 'TITLE: ' + doc['title'] + '\n' + 'PARAGRAPH: ' + doc['paragraph'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: ' return (
"TITLE: "
+ doc["title"]
+ "\n"
+ "PARAGRAPH: "
+ doc["paragraph"]
+ "\n\n"
+ "Q: "
+ doc["question"]
+ "\n\n"
+ "A: "
)
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["paragraph"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc['answer'] return doc["answer"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
...@@ -72,7 +89,7 @@ class QuAC(Task): ...@@ -72,7 +89,7 @@ class QuAC(Task):
part of the document for `doc`. part of the document for `doc`.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -85,7 +102,7 @@ class QuAC(Task): ...@@ -85,7 +102,7 @@ class QuAC(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def aggregation(self): def aggregation(self):
""" """
...@@ -94,7 +111,7 @@ class QuAC(Task): ...@@ -94,7 +111,7 @@ class QuAC(Task):
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -103,4 +120,4 @@ class QuAC(Task): ...@@ -103,4 +120,4 @@ class QuAC(Task):
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError("Evaluation not implemented")
This diff is collapsed.
...@@ -59,11 +59,19 @@ class SATAnalogies(MultipleChoiceTask): ...@@ -59,11 +59,19 @@ class SATAnalogies(MultipleChoiceTask):
def _process_doc(self, doc): def _process_doc(self, doc):
return { return {
'source': doc['source'], "source": doc["source"],
'query': doc['stem'].split(' ')[:2], "query": doc["stem"].split(" ")[:2],
'choices': ["{} is to {}".format(*c.split(' ')[:2]) for c in doc["choices"]], "choices": [
'gold': ['a', 'b', 'c', 'd', 'e'].index(doc['solution'].strip()), "{} is to {}".format(*c.split(" ")[:2]) for c in doc["choices"]
],
"gold": ["a", "b", "c", "d", "e"].index(doc["solution"].strip()),
} }
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{} is to {} as".format(*doc['query']) return "{} is to {} as".format(*doc["query"])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + "\n" + " ".join(doc["query"])
...@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask): ...@@ -54,10 +54,10 @@ class SciQ(MultipleChoiceTask):
doc["distractor3"], doc["distractor3"],
doc["correct_answer"], doc["correct_answer"],
] ]
src = doc['support'] src = doc["support"]
out_doc = { out_doc = {
"source": src, "source": src,
"query": doc['question'], "query": doc["question"],
"choices": choices, "choices": choices,
"gold": 3, "gold": 3,
} }
...@@ -65,3 +65,9 @@ class SciQ(MultipleChoiceTask): ...@@ -65,3 +65,9 @@ class SciQ(MultipleChoiceTask):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip() return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["source"] + " " + doc["query"]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -43,10 +43,10 @@ class TriviaQA(Task): ...@@ -43,10 +43,10 @@ class TriviaQA(Task):
return False return False
def training_docs(self): def training_docs(self):
return self.dataset['train'] return self.dataset["train"]
def validation_docs(self): def validation_docs(self):
return self.dataset['validation'] return self.dataset["validation"]
def test_docs(self): def test_docs(self):
raise NotImplementedError() raise NotImplementedError()
...@@ -54,8 +54,14 @@ class TriviaQA(Task): ...@@ -54,8 +54,14 @@ class TriviaQA(Task):
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Question: {doc['question']}\nAnswer:" return f"Question: {doc['question']}\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc['answer']['value'] return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases): def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list # Optimization: Remove any alias that has a strict prefix elsewhere in the list
...@@ -69,15 +75,13 @@ class TriviaQA(Task): ...@@ -69,15 +75,13 @@ class TriviaQA(Task):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] ret = []
for alias in self._remove_prefixes(doc['answer']['aliases']): for alias in self._remove_prefixes(doc["answer"]["aliases"]):
_, is_prediction = rf.loglikelihood(ctx, " " + alias) _, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction) ret.append(is_prediction)
return ret return ret
def process_results(self, doc, results): def process_results(self, doc, results):
return { return {"acc": float(any(results))}
"acc": float(any(results))
}
def aggregation(self): def aggregation(self):
return { return {
...@@ -85,6 +89,4 @@ class TriviaQA(Task): ...@@ -85,6 +89,4 @@ class TriviaQA(Task):
} }
def higher_is_better(self): def higher_is_better(self):
return { return {"acc": True}
"acc": True
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -90,6 +90,9 @@ class WikiText(PerplexityTask): ...@@ -90,6 +90,9 @@ class WikiText(PerplexityTask):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return wikitext_detokenizer(doc) return wikitext_detokenizer(doc)
def should_decontaminate(self):
return True
def count_words(self, doc): def count_words(self, doc):
# count number of words in *original doc before detokenization* # count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc)) return len(re.split(r"\s+", doc))
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment