Unverified commit bc5495d2 authored by Leo Gao, committed by GitHub

Merge branch 'master' into wsc273-evaluation

parents b5e86d3f e12d0078
@@ -17,6 +17,7 @@
 from . import lambada
 from . import race
 from . import piqa
 from . import triviaqa
+from . import webqs
 
 TASK_REGISTRY = {
@@ -55,7 +56,7 @@ TASK_REGISTRY = {
     # "squad": squad.SQuAD, # not implemented yet
     "race": race.RACE,
     # "naturalqs": naturalqs.NaturalQs, # not implemented yet
-    # "webqs": webqs.WebQs, # not implemented yet
+    "webqs": webqs.WebQs,
     "wsc273": wsc273.WinogradSchemaChallenge273,
     # "winogrande": winogrande.Winogrande, # not implemented yet
     "anli_r1": anli.ANLIRound1,
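The registry change above is the entire integration step: once the `webqs` key maps to a concrete class, the task is resolvable by name. A minimal sketch of the lookup path, assuming the `lm_eval` package layout this diff implies (`tasks.get_task_dict` is the same helper `main()` calls below):

# Sketch only: resolving the newly registered task by name. Assumes
# get_task_dict maps each requested name to a task instance, as its
# use in main() below suggests; not code taken from this commit.
from lm_eval import tasks

task_dict = tasks.get_task_dict(["webqs"])
task = task_dict["webqs"]

# The HF "web_questions" dataset ships only train and test splits, which
# is presumably why main() below grows a validation-to-test fallback.
print(task.has_validation_docs(), task.has_test_docs())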
 from . common import HFTask
+from lm_eval.base import mean, rf
 
 class WebQs(HFTask):
     DATASET_PATH = "web_questions"
@@ -18,7 +19,6 @@ class WebQs(HFTask):
         return ""
 
     def doc_to_text(self, doc):
-        print(doc)
         return "Q: " + doc['question'] + '\nA:'
 
     def doc_to_target(self, doc):
@@ -27,47 +27,36 @@ class WebQs(HFTask):
         # TODO: make sure we're actually handling multi-answer correctly
         return " " + doc['answers'][0]
 
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
-
-    def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+    def _remove_prefixes(self, aliases):
+        # Optimization: remove any alias that has a strict prefix elsewhere in the list
+        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
+        aliases.sort()
+        ret = [aliases[0]]
+        for alias in aliases[1:]:
+            if not alias.startswith(ret[-1]):
+                ret.append(alias)
+        return ret
+
+    def construct_requests(self, doc, ctx):
+        ret = []
+        for alias in self._remove_prefixes(doc['answers']):
+            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
+            ret.append(is_prediction)
+        return ret
+
+    def process_results(self, doc, results):
+        return {
+            "acc": float(any(results))
+        }
 
     def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean,
+        }
 
     def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": True
+        }
\ No newline at end of file
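The interesting piece in the new `WebQs` code is `_remove_prefixes`. Each surviving alias costs one `rf.loglikelihood` request, whose second return value reports whether the alias is exactly the model's greedy continuation. If alias A is a strict prefix of alias B, a greedy match on B implies a greedy match on A, so B can never flip `any(results)` on its own and is safe to drop. Sorting lexicographically places every string immediately before its extensions, so a single `startswith` check against the last kept entry finds all prefixes. A self-contained sketch of the same pruning (function name and example data mine, not from the diff):

def remove_prefixes(aliases):
    # Standalone version of the helper above, for illustration only.
    # sorted() places every string immediately before its extensions, so
    # comparing against the last kept entry is enough to spot prefixes.
    aliases = sorted(aliases)
    kept = [aliases[0]]
    for alias in aliases[1:]:
        if not alias.startswith(kept[-1]):
            kept.append(alias)
    return kept

print(remove_prefixes(["New York City", "New York", "NYC"]))
# -> ['NYC', 'New York']  ("New York City" is covered by its prefix)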
@@ -32,8 +32,7 @@ def main():
     task_names = args.tasks.split(",")
     task_dict = tasks.get_task_dict(task_names)
-    # TODO: fall back to test docs
-    task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]
+    task_dict_items = [(name, task) for name, task in task_dict.items() if (task.has_validation_docs() or task.has_test_docs())]
 
     results = collections.defaultdict(dict)
@@ -50,7 +49,14 @@ def main():
     # get lists of each type of request
     for task_name, task in task_dict_items:
-        for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
+        # Default to validation docs; fall back to test docs if validation is unavailable.
+        # TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
+        if task.has_validation_docs():
+            task_doc_func = task.validation_docs
+        elif task.has_test_docs():
+            task_doc_func = task.test_docs
+        for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, args.limit)):
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
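One subtlety in the `main()` hunk: the `if`/`elif` chain never binds `task_doc_func` when a task has neither a validation nor a test split, which would raise `UnboundLocalError` at the `islice` call. It is only safe because `task_dict_items` was filtered a few lines earlier to tasks that have at least one of the two. A defensive standalone version of the same selection (helper name hypothetical, not part of this commit):

def pick_doc_source(task):
    # Hypothetical helper mirroring the fallback added in this commit:
    # prefer validation docs, fall back to test docs, and fail loudly
    # rather than rely on an upstream filter to guarantee a split exists.
    if task.has_validation_docs():
        return task.validation_docs
    if task.has_test_docs():
        return task.test_docs
    raise ValueError("task has neither validation nor test docs")

# Usage, matching the loop in main():
#     for doc_id, doc in enumerate(itertools.islice(pick_doc_source(task)(), 0, args.limit)):
#         ...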