Commit f275301a authored by haileyschoelkopf, committed by Hailey Schoelkopf

make tasks and models registered by decorators

parent e7c18e53
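
Editor's note: the change replaces the hand-maintained MODEL_REGISTRY / TASK_REGISTRY dicts with register_model / register_task class decorators, so the registries are populated as a side effect of importing each model or task module. Below is a minimal, self-contained sketch of the pattern for readers skimming the diff; the BaseLM / EchoLM names and the loglikelihood stub are illustrative placeholders, not code from this commit.

import abc

MODEL_REGISTRY = {}


def register_model(name):
    """Class decorator: validate the class and file it in the registry under `name`."""

    def decorate(cls):
        assert issubclass(cls, BaseLM), f"Model '{name}' ({cls.__name__}) must extend BaseLM"
        assert name not in MODEL_REGISTRY, f"Model named '{name}' conflicts with existing model!"
        MODEL_REGISTRY[name] = cls
        return cls  # return the class unchanged, so the decorator is transparent

    return decorate


def get_model(model_name):
    return MODEL_REGISTRY[model_name]


class BaseLM(abc.ABC):
    @abc.abstractmethod
    def loglikelihood(self, requests):
        ...


@register_model("echo")  # registration happens when the module defining the class is imported
class EchoLM(BaseLM):
    def loglikelihood(self, requests):
        return [(0.0, True) for _ in requests]


print(get_model("echo"))       # <class '__main__.EchoLM'>
print(sorted(MODEL_REGISTRY))  # ['echo']

Tasks follow the same shape in the diff below: @register_task("gsm8k") on the task class, get_task("gsm8k") for lookup, and TASK_REGISTRY as the backing dict.
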
......@@ -2,6 +2,29 @@ import abc
from lm_eval import utils
MODEL_REGISTRY = {}
def register_model(name):
# TODO: should fairseq/elk be cited for this design pattern?
def decorate(cls):
assert (
issubclass(cls, LM)
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
name not in MODEL_REGISTRY
), f"Model named '{name}' conflicts with existing model!"
MODEL_REGISTRY[name] = cls
return cls
return decorate
def get_model(model_name):
return MODEL_REGISTRY[model_name]
class LM(abc.ABC):
def __init__(self):
......
......@@ -9,6 +9,8 @@ import itertools
import datasets
import numpy as np
from typing import List, Union
from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
......@@ -31,7 +33,7 @@ class TaskConfig(dict):
# TODO: add this as more jinja2 appended to start of jinja2 templates. Should allow users to set vars
# s.t. they can define e.g. {% set question = query %} to map dataset columns to "canonical" names in prompts.
template_vars: str = None
template_aliases: str = None
doc_to_text: str = None
doc_to_target: str = None
......@@ -609,3 +611,82 @@ class PerplexityTask(Task, abc.ABC):
def count_words(cls, doc):
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
# TODO: confirm we want this to go in this file
TASK_REGISTRY = {}
ALL_TASKS = []
def register_task(name):
def decorate(cls):
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task!"
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import right.
return cls
return decorate
##### Task registry utils and setup.
# ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
try:
return TASK_REGISTRY[task_name]
except KeyError:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
task_name_dict = {
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_config_dict = {
get_task_name_from_config(task_config): ConfigurableTask(
config=task_config
)
for task_config in task_name_list
if isinstance(task_config, dict)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list
if isinstance(task_object, Task)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {
**task_name_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
\ No newline at end of file
......@@ -58,14 +58,14 @@ def simple_evaluate(
if isinstance(model, str):
if model_args is None:
model_args = ""
lm = lm_eval.models.get_model(model).create_from_arg_string(
lm = lm_eval.api.model.get_model(model).create_from_arg_string(
model_args, {"batch_size": batch_size, "device": device}
)
else:
assert isinstance(model, lm_eval.api.model.LM)
lm = model
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
task_dict = lm_eval.api.task.get_task_dict(tasks, num_fewshot=num_fewshot)
if check_integrity:
run_task_tests(task_list=tasks)
......
from lm_eval.api.model import LM, MODEL_REGISTRY
from . import gpt2
from . import gpt3
from . import textsynth
from . import dummy
MODEL_REGISTRY = {
"hf-causal": gpt2.HFLM,
"openai": gpt3.GPT3LM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
}
# MODEL_REGISTRY = {}
# MODEL_REGISTRY = {
# "hf-causal": gpt2.HFLM,
# "openai": gpt3.GPT3LM,
# "textsynth": textsynth.TextSynthLM,
# "dummy": dummy.DummyLM,
# }
def get_model(model_name):
return MODEL_REGISTRY[model_name]
# def get_model(model_name):
# return MODEL_REGISTRY[model_name]
......@@ -6,9 +6,11 @@ from tqdm import tqdm
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.model import LM, register_model
# from lm_eval.models import register_model
@register_model("hf-causal")
class HFLM(LM):
def __init__(
self,
......
......@@ -53,316 +53,316 @@ from . import gsm8k
# from . import toxigen
# from . import crowspairs
########################################
# Translation tasks
########################################
# ########################################
# # Translation tasks
# ########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ["en-fr", "fr-en"], # French
"wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
# # 6 total
# gpt3_translation_benchmarks = {
# "wmt14": ["en-fr", "fr-en"], # French
# "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
# }
# 28 total
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ["en-ar", "ar-en"], # Arabic
}
# # 28 total
# selected_translation_benchmarks = {
# **gpt3_translation_benchmarks,
# "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
# "iwslt17": ["en-ar", "ar-en"], # Arabic
# }
# 319 total
all_translation_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
# # 319 total
# all_translation_benchmarks = {
# ts: sacrebleu.get_langpairs_for_testset(ts)
# for ts in sacrebleu.get_available_testsets()
# }
########################################
# All tasks
########################################
# ########################################
# # All tasks
# ########################################
TASK_REGISTRY = {
# GLUE
# "cola": glue.CoLA,
# "mnli": glue.MNLI,
# "mnli_mismatched": glue.MNLIMismatched,
# "mrpc": glue.MRPC,
# "rte": glue.RTE,
# "qnli": glue.QNLI,
# "qqp": glue.QQP,
# # "stsb": glue.STSB, # not implemented yet
# "sst": glue.SST,
# "wnli": glue.WNLI,
# # SuperGLUE
# "boolq": superglue.BoolQ,
# "cb": superglue.CommitmentBank,
# "copa": superglue.Copa,
# "multirc": superglue.MultiRC,
# "record": superglue.ReCoRD,
# "wic": superglue.WordsInContext,
# "wsc": superglue.SGWinogradSchemaChallenge,
# # Order by benchmark/genre?
# "coqa": coqa.CoQA,
# "drop": drop.DROP,
"lambada_openai": lambada.LambadaOpenAI,
"lambada_standard": lambada.LambadaStandard,
# "lambada_openai_cloze": lambada_cloze.LambadaOpenAICloze,
# "lambada_standard_cloze": lambada_cloze.LambadaStandardCloze,
# # multilingual lambada
# **lambada_multilingual.construct_tasks(),
"wikitext": wikitext.WikiText,
# # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
# "piqa": piqa.PiQA,
# "prost": prost.PROST,
# "mc_taco": mc_taco.MCTACO,
# # Science related
# "pubmedqa": pubmedqa.Pubmed_QA,
# "sciq": sciq.SciQ,
# "qasper": qasper.QASPER,
# "qa4mre_2011": qa4mre.QA4MRE_2011,
# "qa4mre_2012": qa4mre.QA4MRE_2012,
# "qa4mre_2013": qa4mre.QA4MRE_2013,
# "triviaqa": triviaqa.TriviaQA,
"arc_easy": arc.ARCEasy,
"arc_challenge": arc.ARCChallenge,
# # "quac": quac.QuAC, # not implemented yet
# "logiqa": logiqa.LogiQA,
# "hellaswag": hellaswag.HellaSwag,
# "swag": swag.SWAG,
# "openbookqa": openbookqa.OpenBookQA,
# "squad2": squad.SQuAD2,
# "race": race.RACE,
# # "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
# "headqa_es": headqa.HeadQAEs,
# "headqa_en": headqa.HeadQAEn,
# "mathqa": mathqa.MathQA,
# "webqs": webqs.WebQs,
# "wsc273": wsc273.WinogradSchemaChallenge273,
# "winogrande": winogrande.Winogrande,
# "anli_r1": anli.ANLIRound1,
# "anli_r2": anli.ANLIRound2,
# "anli_r3": anli.ANLIRound3,
# "ethics_cm": hendrycks_ethics.EthicsCM,
# "ethics_deontology": hendrycks_ethics.EthicsDeontology,
# "ethics_justice": hendrycks_ethics.EthicsJustice,
# "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
# "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
# "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
# "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# # dialogue
# "mutual": mutual.MuTual,
# "mutual_plus": mutual.MuTualPlus,
# # math
# "math_algebra": hendrycks_math.MathAlgebra,
# "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
# "math_geometry": hendrycks_math.MathGeometry,
# "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
# "math_num_theory": hendrycks_math.MathNumberTheory,
# "math_prealgebra": hendrycks_math.MathPrealgebra,
# "math_precalc": hendrycks_math.MathPrecalculus,
# "math_asdiv": asdiv.Asdiv,
"gsm8k": gsm8k.GradeSchoolMath8K,
# # arithmetic
# "arithmetic_2da": arithmetic.Arithmetic2DPlus,
# "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
# "arithmetic_3da": arithmetic.Arithmetic3DPlus,
# "arithmetic_3ds": arithmetic.Arithmetic3DMinus,
# "arithmetic_4da": arithmetic.Arithmetic4DPlus,
# "arithmetic_4ds": arithmetic.Arithmetic4DMinus,
# "arithmetic_5da": arithmetic.Arithmetic5DPlus,
# "arithmetic_5ds": arithmetic.Arithmetic5DMinus,
# "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
# "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# # TODO Perhaps make these groups of tasks
# # e.g. anli, arithmetic, openai_translations, harness_translations
# # hendrycksTest (57 tasks)
# **hendrycks_test.create_all_tasks(),
# # e.g. wmt14-fr-en
# **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# # chef's selection, mostly wmt20
# **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# # Word Scrambling and Manipulation Tasks
# "anagrams1": unscramble.Anagrams1,
# "anagrams2": unscramble.Anagrams2,
# "cycle_letters": unscramble.CycleLetters,
# "random_insertion": unscramble.RandomInsertion,
# "reversed_words": unscramble.ReversedWords,
# # Pile
# "pile_arxiv": pile.PileArxiv,
# "pile_books3": pile.PileBooks3,
# "pile_bookcorpus2": pile.PileBookCorpus2,
# "pile_dm-mathematics": pile.PileDmMathematics,
# "pile_enron": pile.PileEnron,
# "pile_europarl": pile.PileEuroparl,
# "pile_freelaw": pile.PileFreeLaw,
# "pile_github": pile.PileGithub,
# "pile_gutenberg": pile.PileGutenberg,
# "pile_hackernews": pile.PileHackernews,
# "pile_nih-exporter": pile.PileNIHExporter,
# "pile_opensubtitles": pile.PileOpenSubtitles,
# "pile_openwebtext2": pile.PileOpenWebText2,
# "pile_philpapers": pile.PilePhilPapers,
# "pile_pile-cc": pile.PilePileCc,
# "pile_pubmed-abstracts": pile.PilePubmedAbstracts,
# "pile_pubmed-central": pile.PilePubmedCentral,
# "pile_stackexchange": pile.PileStackExchange,
# "pile_uspto": pile.PileUspto,
# "pile_ubuntu-irc": pile.PileUbuntuIrc,
# "pile_wikipedia": pile.PileWikipedia,
# "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# # BLiMP
# "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
# "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
# "blimp_anaphor_number_agreement": blimp.BlimpAnaphorNumberAgreement,
# "blimp_animate_subject_passive": blimp.BlimpAnimateSubjectPassive,
# "blimp_animate_subject_trans": blimp.BlimpAnimateSubjectTrans,
# "blimp_causative": blimp.BlimpCausative,
# "blimp_complex_NP_island": blimp.BlimpComplex_NPIsland,
# "blimp_coordinate_structure_constraint_complex_left_branch": blimp.BlimpCoordinateStructureConstraintComplexLeftBranch,
# "blimp_coordinate_structure_constraint_object_extraction": blimp.BlimpCoordinateStructureConstraintObjectExtraction,
# "blimp_determiner_noun_agreement_1": blimp.BlimpDeterminerNounAgreement_1,
# "blimp_determiner_noun_agreement_2": blimp.BlimpDeterminerNounAgreement_2,
# "blimp_determiner_noun_agreement_irregular_1": blimp.BlimpDeterminerNounAgreementIrregular_1,
# "blimp_determiner_noun_agreement_irregular_2": blimp.BlimpDeterminerNounAgreementIrregular_2,
# "blimp_determiner_noun_agreement_with_adj_2": blimp.BlimpDeterminerNounAgreementWithAdj_2,
# "blimp_determiner_noun_agreement_with_adj_irregular_1": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_1,
# "blimp_determiner_noun_agreement_with_adj_irregular_2": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_2,
# "blimp_determiner_noun_agreement_with_adjective_1": blimp.BlimpDeterminerNounAgreementWithAdjective_1,
# "blimp_distractor_agreement_relational_noun": blimp.BlimpDistractorAgreementRelationalNoun,
# "blimp_distractor_agreement_relative_clause": blimp.BlimpDistractorAgreementRelativeClause,
# "blimp_drop_argument": blimp.BlimpDropArgument,
# "blimp_ellipsis_n_bar_1": blimp.BlimpEllipsisNBar_1,
# "blimp_ellipsis_n_bar_2": blimp.BlimpEllipsisNBar_2,
# "blimp_existential_there_object_raising": blimp.BlimpExistentialThereObjectRaising,
# "blimp_existential_there_quantifiers_1": blimp.BlimpExistentialThereQuantifiers_1,
# "blimp_existential_there_quantifiers_2": blimp.BlimpExistentialThereQuantifiers_2,
# "blimp_existential_there_subject_raising": blimp.BlimpExistentialThereSubjectRaising,
# "blimp_expletive_it_object_raising": blimp.BlimpExpletiveItObjectRaising,
# "blimp_inchoative": blimp.BlimpInchoative,
# "blimp_intransitive": blimp.BlimpIntransitive,
# "blimp_irregular_past_participle_adjectives": blimp.BlimpIrregularPastParticipleAdjectives,
# "blimp_irregular_past_participle_verbs": blimp.BlimpIrregularPastParticipleVerbs,
# "blimp_irregular_plural_subject_verb_agreement_1": blimp.BlimpIrregularPluralSubjectVerbAgreement_1,
# "blimp_irregular_plural_subject_verb_agreement_2": blimp.BlimpIrregularPluralSubjectVerbAgreement_2,
# "blimp_left_branch_island_echo_question": blimp.BlimpLeftBranchIslandEchoQuestion,
# "blimp_left_branch_island_simple_question": blimp.BlimpLeftBranchIslandSimpleQuestion,
# "blimp_matrix_question_npi_licensor_present": blimp.BlimpMatrixQuestionNpiLicensorPresent,
# "blimp_npi_present_1": blimp.BlimpNpiPresent_1,
# "blimp_npi_present_2": blimp.BlimpNpiPresent_2,
# "blimp_only_npi_licensor_present": blimp.BlimpOnlyNpiLicensorPresent,
# "blimp_only_npi_scope": blimp.BlimpOnlyNpiScope,
# "blimp_passive_1": blimp.BlimpPassive_1,
# "blimp_passive_2": blimp.BlimpPassive_2,
# "blimp_principle_A_c_command": blimp.BlimpPrinciple_ACCommand,
# "blimp_principle_A_case_1": blimp.BlimpPrinciple_ACase_1,
# "blimp_principle_A_case_2": blimp.BlimpPrinciple_ACase_2,
# "blimp_principle_A_domain_1": blimp.BlimpPrinciple_ADomain_1,
# "blimp_principle_A_domain_2": blimp.BlimpPrinciple_ADomain_2,
# "blimp_principle_A_domain_3": blimp.BlimpPrinciple_ADomain_3,
# "blimp_principle_A_reconstruction": blimp.BlimpPrinciple_AReconstruction,
# "blimp_regular_plural_subject_verb_agreement_1": blimp.BlimpRegularPluralSubjectVerbAgreement_1,
# "blimp_regular_plural_subject_verb_agreement_2": blimp.BlimpRegularPluralSubjectVerbAgreement_2,
# "blimp_sentential_negation_npi_licensor_present": blimp.BlimpSententialNegationNpiLicensorPresent,
# "blimp_sentential_negation_npi_scope": blimp.BlimpSententialNegationNpiScope,
# "blimp_sentential_subject_island": blimp.BlimpSententialSubjectIsland,
# "blimp_superlative_quantifiers_1": blimp.BlimpSuperlativeQuantifiers_1,
# "blimp_superlative_quantifiers_2": blimp.BlimpSuperlativeQuantifiers_2,
# "blimp_tough_vs_raising_1": blimp.BlimpToughVsRaising_1,
# "blimp_tough_vs_raising_2": blimp.BlimpToughVsRaising_2,
# "blimp_transitive": blimp.BlimpTransitive,
# "blimp_wh_island": blimp.BlimpWhIsland,
# "blimp_wh_questions_object_gap": blimp.BlimpWhQuestionsObjectGap,
# "blimp_wh_questions_subject_gap": blimp.BlimpWhQuestionsSubjectGap,
# "blimp_wh_questions_subject_gap_long_distance": blimp.BlimpWhQuestionsSubjectGapLongDistance,
# "blimp_wh_vs_that_no_gap": blimp.BlimpWhVsThatNoGap,
# "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
# "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
# "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# "toxigen": toxigen.ToxiGen,
# "crows_pairs_english": crowspairs.CrowsPairsEnglish,
# "crows_pairs_english_race_color": crowspairs.CrowsPairsEnglishRaceColor,
# "crows_pairs_english_socioeconomic": crowspairs.CrowsPairsEnglishSocioeconomic,
# "crows_pairs_english_gender": crowspairs.CrowsPairsEnglishGender,
# "crows_pairs_english_age": crowspairs.CrowsPairsEnglishAge,
# "crows_pairs_english_religion": crowspairs.CrowsPairsEnglishReligion,
# "crows_pairs_english_disability": crowspairs.CrowsPairsEnglishDisability,
# "crows_pairs_english_sexual_orientation": crowspairs.CrowsPairsEnglishSexualOrientation,
# "crows_pairs_english_nationality": crowspairs.CrowsPairsEnglishNationality,
# "crows_pairs_english_physical_appearance": crowspairs.CrowsPairsEnglishPhysicalAppearance,
# "crows_pairs_english_autre": crowspairs.CrowsPairsEnglishAutre,
# "crows_pairs_french": crowspairs.CrowsPairsFrench,
# "crows_pairs_french_race_color": crowspairs.CrowsPairsFrenchRaceColor,
# "crows_pairs_french_socioeconomic": crowspairs.CrowsPairsFrenchSocioeconomic,
# "crows_pairs_french_gender": crowspairs.CrowsPairsFrenchGender,
# "crows_pairs_french_age": crowspairs.CrowsPairsFrenchAge,
# "crows_pairs_french_religion": crowspairs.CrowsPairsFrenchReligion,
# "crows_pairs_french_disability": crowspairs.CrowsPairsFrenchDisability,
# "crows_pairs_french_sexual_orientation": crowspairs.CrowsPairsFrenchSexualOrientation,
# "crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
# "crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
# "crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
# "sat": sat.SATAnalogies,
}
# TASK_REGISTRY = {
# # GLUE
# # "cola": glue.CoLA,
# # "mnli": glue.MNLI,
# # "mnli_mismatched": glue.MNLIMismatched,
# # "mrpc": glue.MRPC,
# # "rte": glue.RTE,
# # "qnli": glue.QNLI,
# # "qqp": glue.QQP,
# # # "stsb": glue.STSB, # not implemented yet
# # "sst": glue.SST,
# # "wnli": glue.WNLI,
# # # SuperGLUE
# # "boolq": superglue.BoolQ,
# # "cb": superglue.CommitmentBank,
# # "copa": superglue.Copa,
# # "multirc": superglue.MultiRC,
# # "record": superglue.ReCoRD,
# # "wic": superglue.WordsInContext,
# # "wsc": superglue.SGWinogradSchemaChallenge,
# # # Order by benchmark/genre?
# # "coqa": coqa.CoQA,
# # "drop": drop.DROP,
# "lambada_openai": lambada.LambadaOpenAI,
# "lambada_standard": lambada.LambadaStandard,
# # "lambada_openai_cloze": lambada_cloze.LambadaOpenAICloze,
# # "lambada_standard_cloze": lambada_cloze.LambadaStandardCloze,
# # # multilingual lambada
# # **lambada_multilingual.construct_tasks(),
# "wikitext": wikitext.WikiText,
# # # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# # # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
# # "piqa": piqa.PiQA,
# # "prost": prost.PROST,
# # "mc_taco": mc_taco.MCTACO,
# # # Science related
# # "pubmedqa": pubmedqa.Pubmed_QA,
# # "sciq": sciq.SciQ,
# # "qasper": qasper.QASPER,
# # "qa4mre_2011": qa4mre.QA4MRE_2011,
# # "qa4mre_2012": qa4mre.QA4MRE_2012,
# # "qa4mre_2013": qa4mre.QA4MRE_2013,
# # "triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy,
# "arc_challenge": arc.ARCChallenge,
# # # "quac": quac.QuAC, # not implemented yet
# # "logiqa": logiqa.LogiQA,
# # "hellaswag": hellaswag.HellaSwag,
# # "swag": swag.SWAG,
# # "openbookqa": openbookqa.OpenBookQA,
# # "squad2": squad.SQuAD2,
# # "race": race.RACE,
# # # "naturalqs": naturalqs.NaturalQs, # not implemented yet
# # "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
# # "headqa_es": headqa.HeadQAEs,
# # "headqa_en": headqa.HeadQAEn,
# # "mathqa": mathqa.MathQA,
# # "webqs": webqs.WebQs,
# # "wsc273": wsc273.WinogradSchemaChallenge273,
# # "winogrande": winogrande.Winogrande,
# # "anli_r1": anli.ANLIRound1,
# # "anli_r2": anli.ANLIRound2,
# # "anli_r3": anli.ANLIRound3,
# # "ethics_cm": hendrycks_ethics.EthicsCM,
# # "ethics_deontology": hendrycks_ethics.EthicsDeontology,
# # "ethics_justice": hendrycks_ethics.EthicsJustice,
# # "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
# # "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
# # "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# # "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
# # "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# # # dialogue
# # "mutual": mutual.MuTual,
# # "mutual_plus": mutual.MuTualPlus,
# # # math
# # "math_algebra": hendrycks_math.MathAlgebra,
# # "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
# # "math_geometry": hendrycks_math.MathGeometry,
# # "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
# # "math_num_theory": hendrycks_math.MathNumberTheory,
# # "math_prealgebra": hendrycks_math.MathPrealgebra,
# # "math_precalc": hendrycks_math.MathPrecalculus,
# # "math_asdiv": asdiv.Asdiv,
# "gsm8k": gsm8k.GradeSchoolMath8K,
# # # arithmetic
# # "arithmetic_2da": arithmetic.Arithmetic2DPlus,
# # "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
# # "arithmetic_3da": arithmetic.Arithmetic3DPlus,
# # "arithmetic_3ds": arithmetic.Arithmetic3DMinus,
# # "arithmetic_4da": arithmetic.Arithmetic4DPlus,
# # "arithmetic_4ds": arithmetic.Arithmetic4DMinus,
# # "arithmetic_5da": arithmetic.Arithmetic5DPlus,
# # "arithmetic_5ds": arithmetic.Arithmetic5DMinus,
# # "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
# # "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# # # TODO Perhaps make these groups of tasks
# # # e.g. anli, arithmetic, openai_translations, harness_translations
# # # hendrycksTest (57 tasks)
# # **hendrycks_test.create_all_tasks(),
# # # e.g. wmt14-fr-en
# # **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# # # chef's selection, mostly wmt20
# # **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# # # Word Scrambling and Manipulation Tasks
# # "anagrams1": unscramble.Anagrams1,
# # "anagrams2": unscramble.Anagrams2,
# # "cycle_letters": unscramble.CycleLetters,
# # "random_insertion": unscramble.RandomInsertion,
# # "reversed_words": unscramble.ReversedWords,
# # # Pile
# # "pile_arxiv": pile.PileArxiv,
# # "pile_books3": pile.PileBooks3,
# # "pile_bookcorpus2": pile.PileBookCorpus2,
# # "pile_dm-mathematics": pile.PileDmMathematics,
# # "pile_enron": pile.PileEnron,
# # "pile_europarl": pile.PileEuroparl,
# # "pile_freelaw": pile.PileFreeLaw,
# # "pile_github": pile.PileGithub,
# # "pile_gutenberg": pile.PileGutenberg,
# # "pile_hackernews": pile.PileHackernews,
# # "pile_nih-exporter": pile.PileNIHExporter,
# # "pile_opensubtitles": pile.PileOpenSubtitles,
# # "pile_openwebtext2": pile.PileOpenWebText2,
# # "pile_philpapers": pile.PilePhilPapers,
# # "pile_pile-cc": pile.PilePileCc,
# # "pile_pubmed-abstracts": pile.PilePubmedAbstracts,
# # "pile_pubmed-central": pile.PilePubmedCentral,
# # "pile_stackexchange": pile.PileStackExchange,
# # "pile_uspto": pile.PileUspto,
# # "pile_ubuntu-irc": pile.PileUbuntuIrc,
# # "pile_wikipedia": pile.PileWikipedia,
# # "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# # # BLiMP
# # "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
# # "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
# # "blimp_anaphor_number_agreement": blimp.BlimpAnaphorNumberAgreement,
# # "blimp_animate_subject_passive": blimp.BlimpAnimateSubjectPassive,
# # "blimp_animate_subject_trans": blimp.BlimpAnimateSubjectTrans,
# # "blimp_causative": blimp.BlimpCausative,
# # "blimp_complex_NP_island": blimp.BlimpComplex_NPIsland,
# # "blimp_coordinate_structure_constraint_complex_left_branch": blimp.BlimpCoordinateStructureConstraintComplexLeftBranch,
# # "blimp_coordinate_structure_constraint_object_extraction": blimp.BlimpCoordinateStructureConstraintObjectExtraction,
# # "blimp_determiner_noun_agreement_1": blimp.BlimpDeterminerNounAgreement_1,
# # "blimp_determiner_noun_agreement_2": blimp.BlimpDeterminerNounAgreement_2,
# # "blimp_determiner_noun_agreement_irregular_1": blimp.BlimpDeterminerNounAgreementIrregular_1,
# # "blimp_determiner_noun_agreement_irregular_2": blimp.BlimpDeterminerNounAgreementIrregular_2,
# # "blimp_determiner_noun_agreement_with_adj_2": blimp.BlimpDeterminerNounAgreementWithAdj_2,
# # "blimp_determiner_noun_agreement_with_adj_irregular_1": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_1,
# # "blimp_determiner_noun_agreement_with_adj_irregular_2": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_2,
# # "blimp_determiner_noun_agreement_with_adjective_1": blimp.BlimpDeterminerNounAgreementWithAdjective_1,
# # "blimp_distractor_agreement_relational_noun": blimp.BlimpDistractorAgreementRelationalNoun,
# # "blimp_distractor_agreement_relative_clause": blimp.BlimpDistractorAgreementRelativeClause,
# # "blimp_drop_argument": blimp.BlimpDropArgument,
# # "blimp_ellipsis_n_bar_1": blimp.BlimpEllipsisNBar_1,
# # "blimp_ellipsis_n_bar_2": blimp.BlimpEllipsisNBar_2,
# # "blimp_existential_there_object_raising": blimp.BlimpExistentialThereObjectRaising,
# # "blimp_existential_there_quantifiers_1": blimp.BlimpExistentialThereQuantifiers_1,
# # "blimp_existential_there_quantifiers_2": blimp.BlimpExistentialThereQuantifiers_2,
# # "blimp_existential_there_subject_raising": blimp.BlimpExistentialThereSubjectRaising,
# # "blimp_expletive_it_object_raising": blimp.BlimpExpletiveItObjectRaising,
# # "blimp_inchoative": blimp.BlimpInchoative,
# # "blimp_intransitive": blimp.BlimpIntransitive,
# # "blimp_irregular_past_participle_adjectives": blimp.BlimpIrregularPastParticipleAdjectives,
# # "blimp_irregular_past_participle_verbs": blimp.BlimpIrregularPastParticipleVerbs,
# # "blimp_irregular_plural_subject_verb_agreement_1": blimp.BlimpIrregularPluralSubjectVerbAgreement_1,
# # "blimp_irregular_plural_subject_verb_agreement_2": blimp.BlimpIrregularPluralSubjectVerbAgreement_2,
# # "blimp_left_branch_island_echo_question": blimp.BlimpLeftBranchIslandEchoQuestion,
# # "blimp_left_branch_island_simple_question": blimp.BlimpLeftBranchIslandSimpleQuestion,
# # "blimp_matrix_question_npi_licensor_present": blimp.BlimpMatrixQuestionNpiLicensorPresent,
# # "blimp_npi_present_1": blimp.BlimpNpiPresent_1,
# # "blimp_npi_present_2": blimp.BlimpNpiPresent_2,
# # "blimp_only_npi_licensor_present": blimp.BlimpOnlyNpiLicensorPresent,
# # "blimp_only_npi_scope": blimp.BlimpOnlyNpiScope,
# # "blimp_passive_1": blimp.BlimpPassive_1,
# # "blimp_passive_2": blimp.BlimpPassive_2,
# # "blimp_principle_A_c_command": blimp.BlimpPrinciple_ACCommand,
# # "blimp_principle_A_case_1": blimp.BlimpPrinciple_ACase_1,
# # "blimp_principle_A_case_2": blimp.BlimpPrinciple_ACase_2,
# # "blimp_principle_A_domain_1": blimp.BlimpPrinciple_ADomain_1,
# # "blimp_principle_A_domain_2": blimp.BlimpPrinciple_ADomain_2,
# # "blimp_principle_A_domain_3": blimp.BlimpPrinciple_ADomain_3,
# # "blimp_principle_A_reconstruction": blimp.BlimpPrinciple_AReconstruction,
# # "blimp_regular_plural_subject_verb_agreement_1": blimp.BlimpRegularPluralSubjectVerbAgreement_1,
# # "blimp_regular_plural_subject_verb_agreement_2": blimp.BlimpRegularPluralSubjectVerbAgreement_2,
# # "blimp_sentential_negation_npi_licensor_present": blimp.BlimpSententialNegationNpiLicensorPresent,
# # "blimp_sentential_negation_npi_scope": blimp.BlimpSententialNegationNpiScope,
# # "blimp_sentential_subject_island": blimp.BlimpSententialSubjectIsland,
# # "blimp_superlative_quantifiers_1": blimp.BlimpSuperlativeQuantifiers_1,
# # "blimp_superlative_quantifiers_2": blimp.BlimpSuperlativeQuantifiers_2,
# # "blimp_tough_vs_raising_1": blimp.BlimpToughVsRaising_1,
# # "blimp_tough_vs_raising_2": blimp.BlimpToughVsRaising_2,
# # "blimp_transitive": blimp.BlimpTransitive,
# # "blimp_wh_island": blimp.BlimpWhIsland,
# # "blimp_wh_questions_object_gap": blimp.BlimpWhQuestionsObjectGap,
# # "blimp_wh_questions_subject_gap": blimp.BlimpWhQuestionsSubjectGap,
# # "blimp_wh_questions_subject_gap_long_distance": blimp.BlimpWhQuestionsSubjectGapLongDistance,
# # "blimp_wh_vs_that_no_gap": blimp.BlimpWhVsThatNoGap,
# # "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
# # "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
# # "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# # "toxigen": toxigen.ToxiGen,
# # "crows_pairs_english": crowspairs.CrowsPairsEnglish,
# # "crows_pairs_english_race_color": crowspairs.CrowsPairsEnglishRaceColor,
# # "crows_pairs_english_socioeconomic": crowspairs.CrowsPairsEnglishSocioeconomic,
# # "crows_pairs_english_gender": crowspairs.CrowsPairsEnglishGender,
# # "crows_pairs_english_age": crowspairs.CrowsPairsEnglishAge,
# # "crows_pairs_english_religion": crowspairs.CrowsPairsEnglishReligion,
# # "crows_pairs_english_disability": crowspairs.CrowsPairsEnglishDisability,
# # "crows_pairs_english_sexual_orientation": crowspairs.CrowsPairsEnglishSexualOrientation,
# # "crows_pairs_english_nationality": crowspairs.CrowsPairsEnglishNationality,
# # "crows_pairs_english_physical_appearance": crowspairs.CrowsPairsEnglishPhysicalAppearance,
# # "crows_pairs_english_autre": crowspairs.CrowsPairsEnglishAutre,
# # "crows_pairs_french": crowspairs.CrowsPairsFrench,
# # "crows_pairs_french_race_color": crowspairs.CrowsPairsFrenchRaceColor,
# # "crows_pairs_french_socioeconomic": crowspairs.CrowsPairsFrenchSocioeconomic,
# # "crows_pairs_french_gender": crowspairs.CrowsPairsFrenchGender,
# # "crows_pairs_french_age": crowspairs.CrowsPairsFrenchAge,
# # "crows_pairs_french_religion": crowspairs.CrowsPairsFrenchReligion,
# # "crows_pairs_french_disability": crowspairs.CrowsPairsFrenchDisability,
# # "crows_pairs_french_sexual_orientation": crowspairs.CrowsPairsFrenchSexualOrientation,
# # "crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
# # "crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
# # "crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
# # Requires manual download of data.
# # "storycloze_2016": storycloze.StoryCloze2016,
# # "storycloze_2018": storycloze.StoryCloze2018,
# # "sat": sat.SATAnalogies,
# }
ALL_TASKS = sorted(list(TASK_REGISTRY))
# ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
try:
return TASK_REGISTRY[task_name]
except KeyError:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
# def get_task(task_name):
# try:
# return TASK_REGISTRY[task_name]
# except KeyError:
# print("Available tasks:")
# pprint(TASK_REGISTRY)
# raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# def get_task_name_from_object(task_object):
# for name, class_ in TASK_REGISTRY.items():
# if class_ is task_object:
# return name
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
# # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
# return (
# task_object.EVAL_HARNESS_NAME
# if hasattr(task_object, "EVAL_HARNESS_NAME")
# else type(task_object).__name__
# )
def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
# def get_task_name_from_config(task_config):
# return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
def get_task_dict(task_name_list: List[Union[str, dict, api.task.Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
task_name_dict = {
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_config_dict = {
get_task_name_from_config(task_config): api.task.ConfigurableTask(
config=task_config
)
for task_config in task_name_list
if isinstance(task_config, dict)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list
if isinstance(task_object, api.task.Task)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {
**task_name_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
\ No newline at end of file
# def get_task_dict(task_name_list: List[Union[str, dict, api.task.Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
# task_name_dict = {
# task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
# for task_name in task_name_list
# if isinstance(task_name, str)
# }
# task_name_from_config_dict = {
# get_task_name_from_config(task_config): api.task.ConfigurableTask(
# config=task_config
# )
# for task_config in task_name_list
# if isinstance(task_config, dict)
# }
# task_name_from_object_dict = {
# get_task_name_from_object(task_object): task_object
# for task_object in task_name_list
# if isinstance(task_object, api.task.Task)
# }
# assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
# return {
# **task_name_dict,
# **task_name_from_config_dict,
# **task_name_from_object_dict,
# }
\ No newline at end of file
......@@ -12,7 +12,7 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
"""
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.task import MultipleChoiceTask, register_task
from lm_eval.prompts import get_prompt
from lm_eval import utils
......@@ -28,7 +28,7 @@ _CITATION = """
}
"""
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
VERSION = "2.0"
DATASET_PATH = "ai2_arc"
......@@ -80,6 +80,7 @@ class ARCEasy(MultipleChoiceTask):
return doc["query"]
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
DATASET_NAME = "ARC-Challenge"
......@@ -17,7 +17,7 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
import re
from lm_eval.api.task import Task
from lm_eval.api.task import Task, register_task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean
......@@ -41,6 +41,7 @@ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"
@register_task("gsm8k")
class GradeSchoolMath8K(Task):
VERSION = 0
DATASET_PATH = "gsm8k"
......
......@@ -12,7 +12,7 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
from lm_eval.api.task import Task, register_task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity
......@@ -75,6 +75,7 @@ class LambadaBase(Task):
return {"ppl": False, "acc": True}
@register_task("lambada_standard")
class LambadaStandard(LambadaBase):
"""The LAMBADA task using the standard original LAMBADA dataset."""
......@@ -90,7 +91,7 @@ class LambadaStandard(LambadaBase):
def has_test_docs(self):
return True
@register_task("lambada_openai")
class LambadaOpenAI(LambadaBase):
"""The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.
......
......@@ -10,7 +10,7 @@ NOTE: This `Task` is based on WikiText-2.
Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
"""
import re
from lm_eval.api.task import PerplexityTask
from lm_eval.api.task import PerplexityTask, register_task
_CITATION = """
......@@ -58,7 +58,7 @@ def wikitext_detokenizer(string):
return string
@register_task("wikitext")
class WikiText(PerplexityTask):
VERSION = "2.0"
DATASET_PATH = "EleutherAI/wikitext_document_level"
......
......@@ -5,14 +5,17 @@ import fnmatch
import yaml
from lm_eval import tasks, evaluator
from lm_eval.api.task import ConfigurableTask
# import lm_eval.api.task
from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY
logging.getLogger("openai").setLevel(logging.WARNING)
ALL_TASKS = sorted(list(TASK_REGISTRY))
class MultiChoice:
def __init__(self, choices):
self.choices = choices
print(f"{ALL_TASKS} is this")
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
......@@ -31,7 +34,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
parser.add_argument("--config", default=None)
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
......@@ -80,9 +83,9 @@ def main():
task_names.append(config)
else:
task_names = tasks.ALL_TASKS
task_names = ALL_TASKS
else:
task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
print(f"Selected Tasks: {task_names}")
......