"backend/vscode:/vscode.git/clone" did not exist on "58bead039892136ac16e601d37e0dd87a3a75bf3"
Commit 34eb121f authored by Anthony DiPofi

add webqs evaluation and fallback to test set when validation is unavailable

parent 6598967b
...
@@ -17,6 +17,7 @@ from . import lambada
 from . import race
 from . import piqa
 from . import triviaqa
+from . import webqs

 TASK_REGISTRY = {
...
@@ -55,7 +56,7 @@ TASK_REGISTRY = {
     # "squad": squad.SQuAD, # not implemented yet
     "race": race.RACE,
     # "naturalqs": naturalqs.NaturalQs, # not implemented yet
-    # "webqs": webqs.WebQs, # not implemented yet
+    "webqs": webqs.WebQs,
     # "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
     # "winogrande": winogrande.Winogrande, # not implemented yet
     "anli_r1": anli.ANLIRound1,
...
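
With this registry entry, "webqs" resolves to the WebQs task class and can be requested by name. A minimal usage sketch, assuming the lm_eval layout implied by the diff (tasks.get_task_dict and the split predicates used in main() below); this is illustrative, not code from the commit:

from lm_eval import tasks

# get_task_dict resolves each requested name through TASK_REGISTRY.
task_dict = tasks.get_task_dict(["webqs"])
webqs_task = task_dict["webqs"]

# WebQuestions ships only train/test splits, which is why the runner below
# needs a fallback for tasks that have no validation docs.
print(webqs_task.has_validation_docs(), webqs_task.has_test_docs())
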
 from . common import HFTask
+from lm_eval.base import mean, rf

 class WebQs(HFTask):
     DATASET_PATH = "web_questions"
...
@@ -18,7 +19,6 @@ class WebQs(HFTask):
         return ""

     def doc_to_text(self, doc):
-        print(doc)
         return "Q: " + doc['question'] + '\nA:'

     def doc_to_target(self, doc):
...
@@ -26,48 +26,37 @@ class WebQs(HFTask):
         # multiple correct answers being possible.
         # TODO: make sure we're actually handling multi-answer correctly
         return " " + doc['answers'][0]

+    def _remove_prefixes(self, aliases):
+        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
+        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
+        aliases.sort()
+        ret = [aliases[0]]
+        for alias in aliases[1:]:
+            if not alias.startswith(ret[-1]):
+                ret.append(alias)
+        return ret
+
     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ret = []
+        for alias in self._remove_prefixes(doc['answers']):
+            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
+            ret.append(is_prediction)
+        return ret

     def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": float(any(results))
+        }

     def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            "acc": mean,
+        }

     def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            "acc": True
+        }
\ No newline at end of file
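
The prefix trick in _remove_prefixes is easiest to see on concrete data: after sorting, any alias that extends an already-kept alias is dropped, the idea being that if the shorter prefix is accepted by the greedy-match check, every extension of it is redundant. A standalone sketch of the same logic (function name and sample aliases are illustrative, not from the commit):

def remove_prefixes(aliases):
    # Sorting places each prefix immediately before the aliases that extend it,
    # so a single pass can drop every alias that extends the last kept one.
    aliases = sorted(aliases)
    ret = [aliases[0]]
    for alias in aliases[1:]:
        if not alias.startswith(ret[-1]):
            ret.append(alias)
    return ret

# "New York City" extends "New York", so only the shorter alias survives.
print(remove_prefixes(["New York City", "New York", "NYC"]))
# -> ['NYC', 'New York']

# process_results then scores a document 1.0 if any surviving alias was the
# model's greedy continuation, mirroring "acc": float(any(results)).
print(float(any([False, True])))  # -> 1.0

The optimization only trims redundant loglikelihood requests; since an extension can only be a greedy continuation when its prefix is as well, dropping extensions should not change which documents count as correct.
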
...
@@ -32,8 +32,7 @@ def main():
     task_names = args.tasks.split(",")
     task_dict = tasks.get_task_dict(task_names)
-    # TODO: fall back to test docs
-    task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]
+    task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]

     results = collections.defaultdict(dict)
...
@@ -50,7 +49,13 @@ def main():
     # get lists of each type of requeste
     for task_name, task in task_dict_items:
-        for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
+        #default to validation doc, fall back to test doc if validation unavailable
+        if task.has_validation_docs():
+            task_doc_func = task.validation_docs
+        elif task.has_test_docs():
+            task_doc_func = task.test_docs
+
+        for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, args.limit)):
             docs[(task_name, doc_id)] = doc

             ctx = task.fewshot_context(
...
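
The validation-to-test fallback needs only the two has_*_docs predicates, and the earlier filter on task_dict_items guarantees at least one branch fires, so task_doc_func is always bound before the loop. A minimal sketch of the selection logic in isolation (StubTask is an invented stand-in, not part of the harness):

import itertools

class StubTask:
    # Invented stand-in for a task that, like WebQs, has no validation split.
    def has_validation_docs(self): return False
    def has_test_docs(self): return True
    def validation_docs(self): return iter(())
    def test_docs(self): return ({"question": "q%d" % i} for i in range(5))

task = StubTask()
# Prefer validation docs; fall back to test docs, mirroring the diff above.
if task.has_validation_docs():
    task_doc_func = task.validation_docs
elif task.has_test_docs():
    task_doc_func = task.test_docs

# islice caps the number of docs, as args.limit does in main().
for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, 3)):
    print(doc_id, doc)
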