Unverified Commit bc5495d2 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge branch 'master' into wsc273-evaluation

parents b5e86d3f e12d0078
......@@ -17,6 +17,7 @@ from . import lambada
from . import race
from . import piqa
from . import triviaqa
from . import webqs
TASK_REGISTRY = {
......@@ -55,7 +56,7 @@ TASK_REGISTRY = {
# "squad": squad.SQuAD, # not implemented yet
"race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "webqs": webqs.WebQs, # not implemented yet
"webqs": webqs.WebQs,
"wsc273": wsc273.WinogradSchemaChallenge273,
# "winogrande": winogrande.Winogrande, # not implemented yet
"anli_r1": anli.ANLIRound1,
......
from . common import HFTask
from lm_eval.base import mean, rf
class WebQs(HFTask):
DATASET_PATH = "web_questions"
......@@ -18,7 +19,6 @@ class WebQs(HFTask):
return ""
def doc_to_text(self, doc):
print(doc)
return "Q: " + doc['question'] + '\nA:'
def doc_to_target(self, doc):
......@@ -26,48 +26,37 @@ class WebQs(HFTask):
# multiple correct answers being possible.
# TODO: make sure we're actually handling multi-answer correctly
return " " + doc['answers'][0]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
ret = []
for alias in self._remove_prefixes(doc['answers']):
_, is_prediction = rf.loglikelihood(ctx, " " + alias)
ret.append(is_prediction)
return ret
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": float(any(results))
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": mean,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"acc": True
}
\ No newline at end of file
......@@ -32,8 +32,7 @@ def main():
task_names = args.tasks.split(",")
task_dict = tasks.get_task_dict(task_names)
# TODO: fall back to test docs
task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]
task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
results = collections.defaultdict(dict)
......@@ -50,7 +49,14 @@ def main():
# get lists of each type of requeste
for task_name, task in task_dict_items:
for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
#default to validation doc, fall back to test doc if validation unavailable
# TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
if task.has_validation_docs():
task_doc_func = task.validation_docs
elif task.has_test_docs():
task_doc_func = task.test_docs
for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, args.limit)):
docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment