Commit 88745155 authored by cjlovering's avatar cjlovering
Browse files

Initial integration

parent 6caa0afd
This diff is collapsed.
...@@ -6,21 +6,33 @@ import lm_eval.metrics ...@@ -6,21 +6,33 @@ import lm_eval.metrics
import lm_eval.models import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.base import lm_eval.base
import promptsource
import numpy as np import numpy as np
from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated @positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[], def simple_evaluate(
num_fewshot=0, batch_size=None, device=None, model,
no_cache=False, limit=None, bootstrap_iters=100000, model_args=None,
description_dict=None, check_integrity=False): tasks=[],
num_fewshot=0,
batch_size=None,
device=None,
no_cache=False,
limit=None,
bootstrap_iters=100000,
description_dict=None,
check_integrity=False,
):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM] :param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str] :param model_args: Optional[str]
String arguments for each model class, see LM.create_from_arg_string. String arguments for each model class, see LM.create_from_arg_string.
Ignored if `model` argument is a LM object. Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, Task]] :param tasks: list[Union[str, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
...@@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
:param bootstrap_iters: :param bootstrap_iters:
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:param description_dict: dict[str, str] :param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool :param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks Whether to run the relevant part of the test suite for the tasks
:return :return
...@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
assert tasks != [], "No tasks specified" assert tasks != [], "No tasks specified"
if isinstance(model, str): if isinstance(model, str):
if model_args is None: model_args = "" if model_args is None:
lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, { model_args = ""
'batch_size': batch_size, 'device': device lm = lm_eval.models.get_model(model).create_from_arg_string(
}) model_args, {"batch_size": batch_size, "device": device}
)
else: else:
assert isinstance(model, lm_eval.base.LM) assert isinstance(model, lm_eval.base.LM)
lm = model lm = model
if not no_cache: if not no_cache:
lm = lm_eval.base.CachingLM( lm = lm_eval.base.CachingLM(
lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db' lm,
"lm_cache/"
+ model
+ "_"
+ model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+ ".db",
) )
task_dict = lm_eval.tasks.get_task_dict(tasks) task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)
if check_integrity: if check_integrity:
run_task_tests(task_list=tasks) run_task_tests(task_list=tasks)
...@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict=task_dict, task_dict=task_dict,
num_fewshot=num_fewshot, num_fewshot=num_fewshot,
limit=limit, limit=limit,
description_dict=description_dict description_dict=description_dict,
) )
# add info about the model and few shot config # add info about the model and few shot config
...@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
"no_cache": no_cache, "no_cache": no_cache,
"limit": limit, "limit": limit,
"bootstrap_iters": bootstrap_iters, "bootstrap_iters": bootstrap_iters,
"description_dict": description_dict "description_dict": description_dict,
} }
return results return results
@positional_deprecated @positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None): def evaluate(
lm,
task_dict,
provide_description=None,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
description_dict=None,
):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param lm: obj :param lm: obj
...@@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
:param bootstrap_iters: :param bootstrap_iters:
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:param description_dict: dict[str, str] :param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:return :return
Dictionary of results Dictionary of results
""" """
...@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
assert not provide_description # not implemented. assert not provide_description # not implemented.
if provide_description is not None: if provide_description is not None:
# nudge people to not specify it at all # nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
task_dict_items = [ task_dict_items = [
(name, task) (name, task)
for name, task in task_dict.items() for name, task in task_dict.items()
if(task.has_validation_docs() or task.has_test_docs()) if (task.has_validation_docs() or task.has_test_docs())
] ]
results = collections.defaultdict(dict) results = collections.defaultdict(dict)
...@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
rnd.seed(42) rnd.seed(42)
rnd.shuffle(task_docs) rnd.shuffle(task_docs)
description = description_dict[task_name] if description_dict and task_name in description_dict else "" description = (
description_dict[task_name]
if description_dict and task_name in description_dict
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
docs[(task_name, doc_id)] = doc docs[(task_name, doc_id)] = doc
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
num_fewshot=num_fewshot,
rnd=rnd,
description=description
) )
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
...@@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
print("Running", reqtype, "requests") print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs]) resps = getattr(lm, reqtype)([req.args for req in reqs])
resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)] resps = [
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp)) process_res_queue[(task_name, doc_id)].append((i, resp))
vals = collections.defaultdict(list) vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task # unpack results and sort back in order and return control to Task
...@@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, ...@@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
metrics = task.process_results(doc, requests) metrics = task.process_results(doc, requests)
for metric, value in metrics.items(): for metric, value in metrics.items():
vals[(task_name, metric)].append(value) vals[(task_name, metric)].append(value)
task_name, prompt_name = task_name.split("+")
results[task_name]["task_name"] = task_name
results[task_name]["prompt_name"] = prompt_name
# aggregate results # aggregate results
for (task_name, metric), items in vals.items(): for (task_name, metric), items in vals.items():
task = task_dict[task_name] task = task_dict[task_name]
results[task_name][metric] = task.aggregation()[metric](items) results[task_name][metric] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric( stderr = lm_eval.metrics.stderr_for_metric(
metric=task.aggregation()[metric], metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters, bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
) )
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items) results[task_name][metric + "_stderr"] = stderr(items)
return { return {"results": dict(results), "versions": dict(versions)}
"results": dict(results),
"versions": dict(versions)
}
def make_table(result_dict): def make_table(result_dict):
...@@ -247,9 +282,9 @@ def make_table(result_dict): ...@@ -247,9 +282,9 @@ def make_table(result_dict):
if m + "_stderr" in dic: if m + "_stderr" in dic:
se = dic[m + "_stderr"] se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se]) values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
else: else:
values.append([k, version, m, '%.4f' % v, '', '']) values.append([k, version, m, "%.4f" % v, "", ""])
k = "" k = ""
version = "" version = ""
md_writer.value_matrix = values md_writer.value_matrix = values
......
from promptsource.templates import DatasetTemplates
from pprint import pprint from pprint import pprint
from typing import List, Union from typing import List, Union
...@@ -58,8 +60,8 @@ from . import storycloze ...@@ -58,8 +60,8 @@ from . import storycloze
# 6 total # 6 total
gpt3_translation_benchmarks = { gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French "wmt14": ["en-fr", "fr-en"], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
} }
...@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = { ...@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = { selected_translation_benchmarks = {
**gpt3_translation_benchmarks, **gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"), "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic "iwslt17": ["en-ar", "ar-en"], # Arabic
} }
# 319 total # 319 total
...@@ -91,7 +93,7 @@ TASK_REGISTRY = { ...@@ -91,7 +93,7 @@ TASK_REGISTRY = {
"rte": glue.RTE, "rte": glue.RTE,
"qnli": glue.QNLI, "qnli": glue.QNLI,
"qqp": glue.QQP, "qqp": glue.QQP,
#"stsb": glue.STSB, # not implemented yet # "stsb": glue.STSB, # not implemented yet
"sst": glue.SST, "sst": glue.SST,
"wnli": glue.WNLI, "wnli": glue.WNLI,
# SuperGLUE # SuperGLUE
...@@ -102,34 +104,26 @@ TASK_REGISTRY = { ...@@ -102,34 +104,26 @@ TASK_REGISTRY = {
"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge, "wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre? # Order by benchmark/genre?
"coqa": coqa.CoQA, "coqa": coqa.CoQA,
"drop": drop.DROP, "drop": drop.DROP,
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze, "lambada_cloze": lambada_cloze.LAMBADA_cloze,
# multilingual lambada # multilingual lambada
**lambada_multilingual.construct_tasks(), **lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText, "wikitext": wikitext.WikiText,
# "cbt-cn": cbt.CBTCN, # disabled pending context length fix # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# "cbt-ne": cbt.CBTNE, # disabled pending context length fix # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
"piqa": piqa.PiQA, "piqa": piqa.PiQA,
"prost": prost.PROST, "prost": prost.PROST,
"mc_taco": mc_taco.MCTACO, "mc_taco": mc_taco.MCTACO,
# Science related # Science related
"pubmedqa" : pubmedqa.Pubmed_QA, "pubmedqa": pubmedqa.Pubmed_QA,
"sciq" : sciq.SciQ, "sciq": sciq.SciQ,
"qasper": qasper.QASPER, "qasper": qasper.QASPER,
"qa4mre_2011": qa4mre.QA4MRE_2011,
"qa4mre_2011" : qa4mre.QA4MRE_2011, "qa4mre_2012": qa4mre.QA4MRE_2012,
"qa4mre_2012" : qa4mre.QA4MRE_2012, "qa4mre_2013": qa4mre.QA4MRE_2013,
"qa4mre_2013" : qa4mre.QA4MRE_2013,
"triviaqa": triviaqa.TriviaQA, "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy, "arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge, "arc_challenge": arc.ARCChallenge,
...@@ -140,7 +134,7 @@ TASK_REGISTRY = { ...@@ -140,7 +134,7 @@ TASK_REGISTRY = {
"squad2": squad.SQuAD2, "squad2": squad.SQuAD2,
"race": race.RACE, "race": race.RACE,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet # "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
"headqa_es": headqa.HeadQAEs, "headqa_es": headqa.HeadQAEs,
"headqa_en": headqa.HeadQAEn, "headqa_en": headqa.HeadQAEn,
"mathqa": mathqa.MathQA, "mathqa": mathqa.MathQA,
...@@ -150,21 +144,17 @@ TASK_REGISTRY = { ...@@ -150,21 +144,17 @@ TASK_REGISTRY = {
"anli_r1": anli.ANLIRound1, "anli_r1": anli.ANLIRound1,
"anli_r2": anli.ANLIRound2, "anli_r2": anli.ANLIRound2,
"anli_r3": anli.ANLIRound3, "anli_r3": anli.ANLIRound3,
"ethics_cm": hendrycks_ethics.EthicsCM, "ethics_cm": hendrycks_ethics.EthicsCM,
"ethics_deontology": hendrycks_ethics.EthicsDeontology, "ethics_deontology": hendrycks_ethics.EthicsDeontology,
"ethics_justice": hendrycks_ethics.EthicsJustice, "ethics_justice": hendrycks_ethics.EthicsJustice,
"ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal, "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
"ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism, "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
"ethics_virtue": hendrycks_ethics.EthicsVirtue, "ethics_virtue": hendrycks_ethics.EthicsVirtue,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
"truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice, "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
"truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# dialogue # dialogue
"mutual": mutual.MuTual, "mutual": mutual.MuTual,
"mutual_plus": mutual.MuTualPlus, "mutual_plus": mutual.MuTualPlus,
# math # math
"math_algebra": hendrycks_math.MathAlgebra, "math_algebra": hendrycks_math.MathAlgebra,
"math_counting_and_prob": hendrycks_math.MathCountingAndProbability, "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
...@@ -175,7 +165,6 @@ TASK_REGISTRY = { ...@@ -175,7 +165,6 @@ TASK_REGISTRY = {
"math_precalc": hendrycks_math.MathPrecalculus, "math_precalc": hendrycks_math.MathPrecalculus,
"math_asdiv": asdiv.Asdiv, "math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K, "gsm8k": gsm8k.GradeSchoolMath8K,
# arithmetic # arithmetic
"arithmetic_2da": arithmetic.Arithmetic2DPlus, "arithmetic_2da": arithmetic.Arithmetic2DPlus,
"arithmetic_2ds": arithmetic.Arithmetic2DMinus, "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...@@ -189,22 +178,18 @@ TASK_REGISTRY = { ...@@ -189,22 +178,18 @@ TASK_REGISTRY = {
"arithmetic_1dc": arithmetic.Arithmetic1DComposite, "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks # TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations # e.g. anli, arithmetic, openai_translations, harness_translations
# hendrycksTest (57 tasks) # hendrycksTest (57 tasks)
**hendrycks_test.create_all_tasks(), **hendrycks_test.create_all_tasks(),
# e.g. wmt14-fr-en # e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks), **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20 # chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks), **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# Word Scrambling and Manipulation Tasks # Word Scrambling and Manipulation Tasks
"anagrams1": unscramble.Anagrams1, "anagrams1": unscramble.Anagrams1,
"anagrams2": unscramble.Anagrams2, "anagrams2": unscramble.Anagrams2,
"cycle_letters": unscramble.CycleLetters, "cycle_letters": unscramble.CycleLetters,
"random_insertion": unscramble.RandomInsertion, "random_insertion": unscramble.RandomInsertion,
"reversed_words": unscramble.ReversedWords, "reversed_words": unscramble.ReversedWords,
# Pile # Pile
"pile_arxiv": pile.PileArxiv, "pile_arxiv": pile.PileArxiv,
"pile_books3": pile.PileBooks3, "pile_books3": pile.PileBooks3,
...@@ -228,7 +213,6 @@ TASK_REGISTRY = { ...@@ -228,7 +213,6 @@ TASK_REGISTRY = {
"pile_ubuntu-irc": pile.PileUbuntuIrc, "pile_ubuntu-irc": pile.PileUbuntuIrc,
"pile_wikipedia": pile.PileWikipedia, "pile_wikipedia": pile.PileWikipedia,
"pile_youtubesubtitles": pile.PileYoutubeSubtitles, "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# BLiMP # BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland, "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement, "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...@@ -297,7 +281,6 @@ TASK_REGISTRY = { ...@@ -297,7 +281,6 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# Requires manual download of data. # Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018, # "storycloze_2018": storycloze.StoryCloze2018,
...@@ -321,19 +304,43 @@ def get_task_name_from_object(task_object): ...@@ -321,19 +304,43 @@ def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items(): for name, class_ in TASK_REGISTRY.items():
if class_ is task_object: if class_ is task_object:
return name return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__ return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
task_name_dict = { task_name_dict = {
task_name: get_task(task_name)() task_name: get_task(task_name)()
for task_name in task_name_list if isinstance(task_name, str) for task_name in task_name_list
if isinstance(task_name, str)
} }
task_name_from_object_dict = { task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object get_task_name_from_object(task_object): task_object
for task_object in task_name_list if not isinstance(task_object, str) for task_object in task_name_list
if not isinstance(task_object, str)
} }
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {**task_name_dict, **task_name_from_object_dict} return {**task_name_dict, **task_name_from_object_dict}
def get_task_dict_promptsource(task_name_list: List[str]):
"""Loads a task instance for each prompt written for that task."""
task_name_dict = {}
for task_name in task_name_list:
assert isinstance(task_name, str)
task_prompts = DatasetTemplates(task_name)
for prompt_name in task_prompts.all_template_names:
prompt = task_prompts[prompt_name]
# NOTE: We choose a sep that can be easily split.
task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(
prompt=prompt
)
return task_name_dict
...@@ -51,44 +51,22 @@ class CoQA(Task): ...@@ -51,44 +51,22 @@ class CoQA(Task):
def test_docs(self): def test_docs(self):
pass pass
def doc_to_text(self, doc): # @classmethod
# Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1} # def get_answers(cls, doc, turn_id):
# and a question qi, the task is to predict the answer ai # # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
doc_text = doc["story"] + '\n\n' # answers = []
for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]): # omit target answer ai # answer_forturn = doc["answers"]["input_text"][turn_id - 1]
question = f"Q: {q}\n\n" # answers.append(answer_forturn)
answer = f"A: {a}\n\n" if a is not None else "A:"
doc_text += question + answer # additional_answers = doc.get("additional_answers")
return doc_text # if additional_answers:
# for key in additional_answers:
@classmethod # additional_answer_for_turn = additional_answers[key]["input_text"][
def get_answers(cls, doc, turn_id): # turn_id - 1
# Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers). # ]
answers = [] # if additional_answer_for_turn.lower() not in map(str.lower, answers):
answer_forturn = doc["answers"]["input_text"][turn_id - 1] # answers.append(additional_answer_for_turn)
answers.append(answer_forturn) # return answers
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@classmethod
def get_answer_choice(self, raw_text):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # Not a yes/no question
@staticmethod @staticmethod
def compute_scores(gold_list, pred): def compute_scores(gold_list, pred):
...@@ -98,40 +76,38 @@ class CoQA(Task): ...@@ -98,40 +76,38 @@ class CoQA(Task):
em_sum = 0.0 em_sum = 0.0
if len(gold_list) > 1: if len(gold_list) > 1:
for i in range(len(gold_list)): for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:] gold_answers = gold_list[0:i] + gold_list[i + 1 :]
# predictions compared against (n) golds and take maximum # predictions compared against (n) golds and take maximum
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers) em_sum += max(
squad_metrics.compute_exact(a, pred) for a in gold_answers
)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else: else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list) em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list) f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))} return {
"em": em_sum / max(1, len(gold_list)),
def doc_to_target(self, doc, turnid=None): "f1": f1_sum / max(1, len(gold_list)),
# Default to prediction of last turn. }
if turnid is None:
turnid = len(doc["questions"]["input_text"])
raw_text = doc['answers']["input_text"][turnid - 1]
return " " + raw_text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
:param doc: :param doc:
The document as returned from training_docs, validation_docs, or test_docs. The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str :param ctx: str
The context string, generated by fewshot_context. This includes the natural The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
cont_request = rf.greedy_until(ctx, ['\nQ:']) cont_request = rf.greedy_until(ctx, ["\nQ:"])
return cont_request return cont_request
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -139,15 +115,18 @@ class CoQA(Task): ...@@ -139,15 +115,18 @@ class CoQA(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
turn_id = len(doc["questions"]["input_text"]) target = self.doc_to_target(doc).strip()
gold_list = self.get_answers(doc, turn_id) pred = results[0].strip().split("\n")[0]
pred = results[0].strip().split('\n')[0]
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
scores = self.compute_scores(gold_list, pred) # TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)
return { return {
"f1": scores['f1'], "f1": scores["f1"],
"em": scores['em'], "em": scores["em"],
} }
def higher_is_better(self): def higher_is_better(self):
......
...@@ -70,21 +70,26 @@ class DROP(Task): ...@@ -70,21 +70,26 @@ class DROP(Task):
@classmethod @classmethod
def get_answers(cls, qa): def get_answers(cls, qa):
def _flatten_validated_answers(validated_answers): def _flatten_validated_answers(validated_answers):
""" Flattens a dict of lists of validated answers. """Flattens a dict of lists of validated answers.
{"number": ['1', '8'], ...} {"number": ['1', '8'], ...}
-> [{"number": ['1'], ...}, {"number": ['8'], ...}] -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
""" """
vas = [] vas = []
for i in range(len(validated_answers["number"])): for i in range(len(validated_answers["number"])):
vas.append({ vas.append(
"number": validated_answers["number"][i], {
"date": validated_answers["date"][i], "number": validated_answers["number"][i],
"spans": validated_answers["spans"][i], "date": validated_answers["date"][i],
}) "spans": validated_answers["spans"][i],
}
)
return vas return vas
answers = [] answers = []
answers_set = set() answers_set = set()
candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"]) candidates = [qa["answer"]] + _flatten_validated_answers(
qa["validated_answers"]
)
for candidate in candidates: for candidate in candidates:
answer = cls.parse_answer(candidate) answer = cls.parse_answer(candidate)
if answer in answers_set: if answer in answers_set:
...@@ -100,15 +105,17 @@ class DROP(Task): ...@@ -100,15 +105,17 @@ class DROP(Task):
return (str(answer["number"]),) return (str(answer["number"]),)
if answer["spans"] != []: if answer["spans"] != []:
return tuple(answer["spans"]) return tuple(answer["spans"])
return (" ".join([answer["date"]["day"], return (
answer["date"]["month"], " ".join(
answer["date"]["year"]]).strip(),) [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
def doc_to_text(self, doc): # def doc_to_text(self, doc):
return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:" # return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc): # def doc_to_target(self, doc):
return " " + ", ".join(doc["answers"][0]) # return " " + ", ".join(doc["answers"][0])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
...@@ -134,7 +141,13 @@ class DROP(Task): ...@@ -134,7 +141,13 @@ class DROP(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
preds, golds = results, doc["answers"]
pred = results[0].strip()
target = self.doc_to_target(doc).strip()
preds = [pred]
golds = [target]
max_em = 0 max_em = 0
max_f1 = 0 max_f1 = 0
for gold_answer in golds: for gold_answer in golds:
...@@ -142,10 +155,7 @@ class DROP(Task): ...@@ -142,10 +155,7 @@ class DROP(Task):
if gold_answer[0].strip(): if gold_answer[0].strip():
max_em = max(max_em, exact_match) max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score) max_f1 = max(max_f1, f1_score)
return { return {"em": max_em, "f1": max_f1}
"em": max_em,
"f1": max_f1
}
def get_metrics(self, predicted, gold): def get_metrics(self, predicted, gold):
""" """
...@@ -158,7 +168,9 @@ class DROP(Task): ...@@ -158,7 +168,9 @@ class DROP(Task):
predicted_bags = self._answer_to_bags(predicted) predicted_bags = self._answer_to_bags(predicted)
gold_bags = self._answer_to_bags(gold) gold_bags = self._answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): if set(predicted_bags[0]) == set(gold_bags[0]) and len(
predicted_bags[0]
) == len(gold_bags[0]):
exact_match = 1.0 exact_match = 1.0
else: else:
exact_match = 0.0 exact_match = 0.0
...@@ -190,7 +202,9 @@ class DROP(Task): ...@@ -190,7 +202,9 @@ class DROP(Task):
for gold_index, gold_item in enumerate(gold): for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted): for pred_index, pred_item in enumerate(predicted):
if self._match_numbers_if_present(gold_item, pred_item): if self._match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item) scores[gold_index, pred_index] = self._compute_f1(
pred_item, gold_item
)
row_ind, col_ind = linear_sum_assignment(-scores) row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))]) max_scores = np.zeros([max(len(gold), len(predicted))])
...@@ -256,7 +270,11 @@ class DROP(Task): ...@@ -256,7 +270,11 @@ class DROP(Task):
def _normalize(self, answer): def _normalize(self, answer):
tokens = [ tokens = [
self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower())))) self._white_space_fix(
self._remove_articles(
self._fix_number(self._remove_punc(token.lower()))
)
)
for token in self._tokenize(answer) for token in self._tokenize(answer)
] ]
tokens = [token for token in tokens if token.strip()] tokens = [token for token in tokens if token.strip()]
...@@ -269,10 +287,7 @@ class DROP(Task): ...@@ -269,10 +287,7 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"em": mean, "f1": mean}
"em": mean,
"f1": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
...@@ -280,7 +295,4 @@ class DROP(Task): ...@@ -280,7 +295,4 @@ class DROP(Task):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"em": True, "f1": True}
"em": True,
"f1": True
}
...@@ -40,7 +40,7 @@ class RACE(Task): ...@@ -40,7 +40,7 @@ class RACE(Task):
DATASET_NAME = "high" DATASET_NAME = "high"
cache = {} cache = {}
letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3} letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -59,17 +59,27 @@ class RACE(Task): ...@@ -59,17 +59,27 @@ class RACE(Task):
# is shown that one document is made per passage. # is shown that one document is made per passage.
r = collections.defaultdict(list) r = collections.defaultdict(list)
for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]: for item in datasets.load_dataset(
r[item['article']].append(item) path=self.DATASET_PATH, name=self.DATASET_NAME
)[set]:
res = list(r.values() >> each(lambda x: { r[item["article"]].append(item)
'article': x[0]['article'],
'problems': x >> each(lambda y: { res = list(
'question': y['question'], r.values()
'answer': y['answer'], >> each(
'options': y['options'], lambda x: {
}) "article": x[0]["article"],
})) "problems": x
>> each(
lambda y: {
"question": y["question"],
"answer": y["answer"],
"options": y["options"],
}
),
}
)
)
self.cache[set] = res self.cache[set] = res
return res return res
...@@ -85,49 +95,48 @@ class RACE(Task): ...@@ -85,49 +95,48 @@ class RACE(Task):
@classmethod @classmethod
def get_answer_option(cls, problem): def get_answer_option(cls, problem):
answer = cls.letter_to_num[problem['answer']] answer = cls.letter_to_num[problem["answer"]]
return problem['options'][answer] return problem["options"][answer]
@classmethod @classmethod
def last_problem(cls, doc): def last_problem(cls, doc):
return doc['problems'][-1] return doc["problems"][-1]
def doc_to_text(self, doc): # def doc_to_text(self, doc):
text = 'Article: ' + doc['article'] + '\n\n' # text = 'Article: ' + doc['article'] + '\n\n'
for problem in doc['problems'][:-1]: # for problem in doc['problems'][:-1]:
if problem['question'][-6:] == ' _ .': # if problem['question'][-6:] == ' _ .':
text += problem['question'][-5:] + self.get_answer_option(problem) + '\n' # text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
else: # else:
question = 'Question: ' + problem['question'] + '\n' # question = 'Question: ' + problem['question'] + '\n'
answer = 'Answer: ' + self.get_answer_option(problem) + '\n' # answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
text += question + answer # text += question + answer
text += self.last_problem(doc)['question'] # text += self.last_problem(doc)['question']
return text # return text
def doc_to_target(self, doc): # def doc_to_target(self, doc):
return " " + self.get_answer_option(self.last_problem(doc)) # return " " + self.get_answer_option(self.last_problem(doc))
def construct_requests(self, doc, ctx): # def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of # """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. # Requests which will be sent to the LM.
:param doc: # :param doc:
The document as returned from training_docs, validation_docs, or test_docs. # The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str # :param ctx: str
The context string, generated by fewshot_context. This includes the natural # The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question # language description, as well as the few shot examples, and the question
part of the document for `doc`. # part of the document for `doc`.
""" # """
problem = self.last_problem(doc) # problem = self.last_problem(doc)
ll_choices = [ # ll_choices = [
rf.loglikelihood(ctx, " " + problem['options'][i])[0] # rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
for i in range(4) # ]
] # return ll_choices
return ll_choices
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
:param doc: :param doc:
...@@ -135,28 +144,24 @@ class RACE(Task): ...@@ -135,28 +144,24 @@ class RACE(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
gold = self.letter_to_num[self.last_problem(doc)['answer']] #
gold = self.letter_to_num[self.doc_to_target(doc)]
# gold = self.letter_to_num[self.last_problem(doc)["answer"]]
pred = np.argmax(results) pred = np.argmax(results)
return { return {"acc": int(pred == gold)}
"acc": int(pred == gold)
}
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [float] -> float} :returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics functions that aggregate a list of metrics
""" """
return { return {"acc": mean}
"acc": mean
}
def higher_is_better(self): def higher_is_better(self):
""" """
:returns: {str: bool} :returns: {str: bool}
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
return { return {"acc": True}
"acc": True
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment