Commit 9f1cb1e7 authored by lintangsutawika

merged with conflict resolved

parents 8f859cd2 0375b792
from lm_eval.logger import eval_logger

from promptsource.templates import DatasetTemplates

# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?

# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts by a "<category>:<name>" identifier.
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}


def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
    # unpack prompt name
    try:
        category_name, prompt_name = prompt_id.split(":")
    except ValueError:
        raise ValueError(
            "expected only a single `:` as separator between "
            f"prompt category and name, but got `{prompt_id}` instead"
        )

    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")

    if category_name == "promptsource":
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")

        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except KeyError:
            raise ValueError(
                f"could not find prompt `{prompt_name}` under category "
                f"`{category_name}` in the prompt registry"
            )
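For reference, example lookups against the module above. The promptsource lookup assumes the `promptsource` package (and its Super-GLUE templates) is installed; the dataset and template names simply mirror the Super-GLUE configs later in this commit.

```python
from lm_eval.prompts import get_prompt

# jinja-style template stored directly in PROMPT_REGISTRY
qa_template = get_prompt("qa-basic:question-newline-answer")
# -> "Question: {{question}}\nAnswer:"

# promptsource-backed prompt, resolved per dataset (and optional subset)
boolq_prompt = get_prompt(
    "promptsource:GPT-3 Style", dataset_name="super_glue", subset_name="boolq"
)
```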
from pprint import pprint
import os
from typing import List, Union
import sacrebleu
from lm_eval import api
# from . import superglue
# from . import glue
from . import arc
# from . import coqa
# from . import race
# from . import webqs
# from . import anli
# from . import wsc273
# from . import winogrande
# from . import quac
# from . import hellaswag
# from . import swag
# from . import openbookqa
# from . import squad
# from . import naturalqs
# from . import sat
# from . import arithmetic
from . import lambada
# from . import piqa
# from . import prost
# from . import mc_taco
# from . import triviaqa
# from . import pubmedqa
# from . import sciq
# from . import qasper
# from . import qa4mre
# from . import translation
# from . import headqa
# from . import mathqa
# from . import hendrycks_ethics
# from . import drop
# from . import unscramble
# from . import logiqa
# from . import hendrycks_test
# from . import hendrycks_math
# from . import cbt
# from . import lambada_cloze
from . import pile
from . import wikitext
# from . import lambada_multilingual
# from . import mutual
# from . import truthfulqa
# from . import blimp
# from . import asdiv
from . import gsm8k
# from . import storycloze
# from . import toxigen
# from . import crowspairs
# ########################################
# # Translation tasks
# ########################################
# # 6 total
# gpt3_translation_benchmarks = {
# "wmt14": ["en-fr", "fr-en"], # French
# "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
# }
# # 28 total
# selected_translation_benchmarks = {
# **gpt3_translation_benchmarks,
# "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
# "iwslt17": ["en-ar", "ar-en"], # Arabic
# }
# # 319 total
# all_translation_benchmarks = {
# ts: sacrebleu.get_langpairs_for_testset(ts)
# for ts in sacrebleu.get_available_testsets()
# }
# ########################################
# # All tasks
# ########################################
# TASK_REGISTRY = {
# # GLUE
# # "cola": glue.CoLA,
# # "mnli": glue.MNLI,
# # "mnli_mismatched": glue.MNLIMismatched,
# # "mrpc": glue.MRPC,
# # "rte": glue.RTE,
# # "qnli": glue.QNLI,
# # "qqp": glue.QQP,
# # # "stsb": glue.STSB, # not implemented yet
# # "sst": glue.SST,
# # "wnli": glue.WNLI,
# # # SuperGLUE
# # "boolq": superglue.BoolQ,
# # "cb": superglue.CommitmentBank,
# # "copa": superglue.Copa,
# # "multirc": superglue.MultiRC,
# # "record": superglue.ReCoRD,
# # "wic": superglue.WordsInContext,
# # "wsc": superglue.SGWinogradSchemaChallenge,
# # # Order by benchmark/genre?
# # "coqa": coqa.CoQA,
# # "drop": drop.DROP,
# "lambada_openai": lambada.LambadaOpenAI,
# "lambada_standard": lambada.LambadaStandard,
# # "lambada_openai_cloze": lambada_cloze.LambadaOpenAICloze,
# # "lambada_standard_cloze": lambada_cloze.LambadaStandardCloze,
# # # multilingual lambada
# # **lambada_multilingual.construct_tasks(),
# "wikitext": wikitext.WikiText,
# # # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# # # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
# # "piqa": piqa.PiQA,
# # "prost": prost.PROST,
# # "mc_taco": mc_taco.MCTACO,
# # # Science related
# # "pubmedqa": pubmedqa.Pubmed_QA,
# # "sciq": sciq.SciQ,
# # "qasper": qasper.QASPER,
# # "qa4mre_2011": qa4mre.QA4MRE_2011,
# # "qa4mre_2012": qa4mre.QA4MRE_2012,
# # "qa4mre_2013": qa4mre.QA4MRE_2013,
# # "triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy,
# "arc_challenge": arc.ARCChallenge,
# # # "quac": quac.QuAC, # not implemented yet
# # "logiqa": logiqa.LogiQA,
# # "hellaswag": hellaswag.HellaSwag,
# # "swag": swag.SWAG,
# # "openbookqa": openbookqa.OpenBookQA,
# # "squad2": squad.SQuAD2,
# # "race": race.RACE,
# # # "naturalqs": naturalqs.NaturalQs, # not implemented yet
# # "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
# # "headqa_es": headqa.HeadQAEs,
# # "headqa_en": headqa.HeadQAEn,
# # "mathqa": mathqa.MathQA,
# # "webqs": webqs.WebQs,
# # "wsc273": wsc273.WinogradSchemaChallenge273,
# # "winogrande": winogrande.Winogrande,
# # "anli_r1": anli.ANLIRound1,
# # "anli_r2": anli.ANLIRound2,
# # "anli_r3": anli.ANLIRound3,
# # "ethics_cm": hendrycks_ethics.EthicsCM,
# # "ethics_deontology": hendrycks_ethics.EthicsDeontology,
# # "ethics_justice": hendrycks_ethics.EthicsJustice,
# # "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
# # "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
# # "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# # "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
# # "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# # # dialogue
# # "mutual": mutual.MuTual,
# # "mutual_plus": mutual.MuTualPlus,
# # # math
# # "math_algebra": hendrycks_math.MathAlgebra,
# # "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
# # "math_geometry": hendrycks_math.MathGeometry,
# # "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
# # "math_num_theory": hendrycks_math.MathNumberTheory,
# # "math_prealgebra": hendrycks_math.MathPrealgebra,
# # "math_precalc": hendrycks_math.MathPrecalculus,
# # "math_asdiv": asdiv.Asdiv,
# "gsm8k": gsm8k.GradeSchoolMath8K,
# # # arithmetic
# # "arithmetic_2da": arithmetic.Arithmetic2DPlus,
# # "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
# # "arithmetic_3da": arithmetic.Arithmetic3DPlus,
# # "arithmetic_3ds": arithmetic.Arithmetic3DMinus,
# # "arithmetic_4da": arithmetic.Arithmetic4DPlus,
# # "arithmetic_4ds": arithmetic.Arithmetic4DMinus,
# # "arithmetic_5da": arithmetic.Arithmetic5DPlus,
# # "arithmetic_5ds": arithmetic.Arithmetic5DMinus,
# # "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
# # "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# # # TODO Perhaps make these groups of tasks
# # # e.g. anli, arithmetic, openai_translations, harness_translations
# # # hendrycksTest (57 tasks)
# # **hendrycks_test.create_all_tasks(),
# # # e.g. wmt14-fr-en
# # **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# # # chef's selection, mostly wmt20
# # **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# # # Word Scrambling and Manipulation Tasks
# # "anagrams1": unscramble.Anagrams1,
# # "anagrams2": unscramble.Anagrams2,
# # "cycle_letters": unscramble.CycleLetters,
# # "random_insertion": unscramble.RandomInsertion,
# # "reversed_words": unscramble.ReversedWords,
# # # Pile
# # "pile_arxiv": pile.PileArxiv,
# # "pile_books3": pile.PileBooks3,
# # "pile_bookcorpus2": pile.PileBookCorpus2,
# # "pile_dm-mathematics": pile.PileDmMathematics,
# # "pile_enron": pile.PileEnron,
# # "pile_europarl": pile.PileEuroparl,
# # "pile_freelaw": pile.PileFreeLaw,
# # "pile_github": pile.PileGithub,
# # "pile_gutenberg": pile.PileGutenberg,
# # "pile_hackernews": pile.PileHackernews,
# # "pile_nih-exporter": pile.PileNIHExporter,
# # "pile_opensubtitles": pile.PileOpenSubtitles,
# # "pile_openwebtext2": pile.PileOpenWebText2,
# # "pile_philpapers": pile.PilePhilPapers,
# # "pile_pile-cc": pile.PilePileCc,
# # "pile_pubmed-abstracts": pile.PilePubmedAbstracts,
# # "pile_pubmed-central": pile.PilePubmedCentral,
# # "pile_stackexchange": pile.PileStackExchange,
# # "pile_uspto": pile.PileUspto,
# # "pile_ubuntu-irc": pile.PileUbuntuIrc,
# # "pile_wikipedia": pile.PileWikipedia,
# # "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# # # BLiMP
# # "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
# # "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
# # "blimp_anaphor_number_agreement": blimp.BlimpAnaphorNumberAgreement,
# # "blimp_animate_subject_passive": blimp.BlimpAnimateSubjectPassive,
# # "blimp_animate_subject_trans": blimp.BlimpAnimateSubjectTrans,
# # "blimp_causative": blimp.BlimpCausative,
# # "blimp_complex_NP_island": blimp.BlimpComplex_NPIsland,
# # "blimp_coordinate_structure_constraint_complex_left_branch": blimp.BlimpCoordinateStructureConstraintComplexLeftBranch,
# # "blimp_coordinate_structure_constraint_object_extraction": blimp.BlimpCoordinateStructureConstraintObjectExtraction,
# # "blimp_determiner_noun_agreement_1": blimp.BlimpDeterminerNounAgreement_1,
# # "blimp_determiner_noun_agreement_2": blimp.BlimpDeterminerNounAgreement_2,
# # "blimp_determiner_noun_agreement_irregular_1": blimp.BlimpDeterminerNounAgreementIrregular_1,
# # "blimp_determiner_noun_agreement_irregular_2": blimp.BlimpDeterminerNounAgreementIrregular_2,
# # "blimp_determiner_noun_agreement_with_adj_2": blimp.BlimpDeterminerNounAgreementWithAdj_2,
# # "blimp_determiner_noun_agreement_with_adj_irregular_1": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_1,
# # "blimp_determiner_noun_agreement_with_adj_irregular_2": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_2,
# # "blimp_determiner_noun_agreement_with_adjective_1": blimp.BlimpDeterminerNounAgreementWithAdjective_1,
# # "blimp_distractor_agreement_relational_noun": blimp.BlimpDistractorAgreementRelationalNoun,
# # "blimp_distractor_agreement_relative_clause": blimp.BlimpDistractorAgreementRelativeClause,
# # "blimp_drop_argument": blimp.BlimpDropArgument,
# # "blimp_ellipsis_n_bar_1": blimp.BlimpEllipsisNBar_1,
# # "blimp_ellipsis_n_bar_2": blimp.BlimpEllipsisNBar_2,
# # "blimp_existential_there_object_raising": blimp.BlimpExistentialThereObjectRaising,
# # "blimp_existential_there_quantifiers_1": blimp.BlimpExistentialThereQuantifiers_1,
# # "blimp_existential_there_quantifiers_2": blimp.BlimpExistentialThereQuantifiers_2,
# # "blimp_existential_there_subject_raising": blimp.BlimpExistentialThereSubjectRaising,
# # "blimp_expletive_it_object_raising": blimp.BlimpExpletiveItObjectRaising,
# # "blimp_inchoative": blimp.BlimpInchoative,
# # "blimp_intransitive": blimp.BlimpIntransitive,
# # "blimp_irregular_past_participle_adjectives": blimp.BlimpIrregularPastParticipleAdjectives,
# # "blimp_irregular_past_participle_verbs": blimp.BlimpIrregularPastParticipleVerbs,
# # "blimp_irregular_plural_subject_verb_agreement_1": blimp.BlimpIrregularPluralSubjectVerbAgreement_1,
# # "blimp_irregular_plural_subject_verb_agreement_2": blimp.BlimpIrregularPluralSubjectVerbAgreement_2,
# # "blimp_left_branch_island_echo_question": blimp.BlimpLeftBranchIslandEchoQuestion,
# # "blimp_left_branch_island_simple_question": blimp.BlimpLeftBranchIslandSimpleQuestion,
# # "blimp_matrix_question_npi_licensor_present": blimp.BlimpMatrixQuestionNpiLicensorPresent,
# # "blimp_npi_present_1": blimp.BlimpNpiPresent_1,
# # "blimp_npi_present_2": blimp.BlimpNpiPresent_2,
# # "blimp_only_npi_licensor_present": blimp.BlimpOnlyNpiLicensorPresent,
# # "blimp_only_npi_scope": blimp.BlimpOnlyNpiScope,
# # "blimp_passive_1": blimp.BlimpPassive_1,
# # "blimp_passive_2": blimp.BlimpPassive_2,
# # "blimp_principle_A_c_command": blimp.BlimpPrinciple_ACCommand,
# # "blimp_principle_A_case_1": blimp.BlimpPrinciple_ACase_1,
# # "blimp_principle_A_case_2": blimp.BlimpPrinciple_ACase_2,
# # "blimp_principle_A_domain_1": blimp.BlimpPrinciple_ADomain_1,
# # "blimp_principle_A_domain_2": blimp.BlimpPrinciple_ADomain_2,
# # "blimp_principle_A_domain_3": blimp.BlimpPrinciple_ADomain_3,
# # "blimp_principle_A_reconstruction": blimp.BlimpPrinciple_AReconstruction,
# # "blimp_regular_plural_subject_verb_agreement_1": blimp.BlimpRegularPluralSubjectVerbAgreement_1,
# # "blimp_regular_plural_subject_verb_agreement_2": blimp.BlimpRegularPluralSubjectVerbAgreement_2,
# # "blimp_sentential_negation_npi_licensor_present": blimp.BlimpSententialNegationNpiLicensorPresent,
# # "blimp_sentential_negation_npi_scope": blimp.BlimpSententialNegationNpiScope,
# # "blimp_sentential_subject_island": blimp.BlimpSententialSubjectIsland,
# # "blimp_superlative_quantifiers_1": blimp.BlimpSuperlativeQuantifiers_1,
# # "blimp_superlative_quantifiers_2": blimp.BlimpSuperlativeQuantifiers_2,
# # "blimp_tough_vs_raising_1": blimp.BlimpToughVsRaising_1,
# # "blimp_tough_vs_raising_2": blimp.BlimpToughVsRaising_2,
# # "blimp_transitive": blimp.BlimpTransitive,
# # "blimp_wh_island": blimp.BlimpWhIsland,
# # "blimp_wh_questions_object_gap": blimp.BlimpWhQuestionsObjectGap,
# # "blimp_wh_questions_subject_gap": blimp.BlimpWhQuestionsSubjectGap,
# # "blimp_wh_questions_subject_gap_long_distance": blimp.BlimpWhQuestionsSubjectGapLongDistance,
# # "blimp_wh_vs_that_no_gap": blimp.BlimpWhVsThatNoGap,
# # "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
# # "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
# # "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# # "toxigen": toxigen.ToxiGen,
# # "crows_pairs_english": crowspairs.CrowsPairsEnglish,
# # "crows_pairs_english_race_color": crowspairs.CrowsPairsEnglishRaceColor,
# # "crows_pairs_english_socioeconomic": crowspairs.CrowsPairsEnglishSocioeconomic,
# # "crows_pairs_english_gender": crowspairs.CrowsPairsEnglishGender,
# # "crows_pairs_english_age": crowspairs.CrowsPairsEnglishAge,
# # "crows_pairs_english_religion": crowspairs.CrowsPairsEnglishReligion,
# # "crows_pairs_english_disability": crowspairs.CrowsPairsEnglishDisability,
# # "crows_pairs_english_sexual_orientation": crowspairs.CrowsPairsEnglishSexualOrientation,
# # "crows_pairs_english_nationality": crowspairs.CrowsPairsEnglishNationality,
# # "crows_pairs_english_physical_appearance": crowspairs.CrowsPairsEnglishPhysicalAppearance,
# # "crows_pairs_english_autre": crowspairs.CrowsPairsEnglishAutre,
# # "crows_pairs_french": crowspairs.CrowsPairsFrench,
# # "crows_pairs_french_race_color": crowspairs.CrowsPairsFrenchRaceColor,
# # "crows_pairs_french_socioeconomic": crowspairs.CrowsPairsFrenchSocioeconomic,
# # "crows_pairs_french_gender": crowspairs.CrowsPairsFrenchGender,
# # "crows_pairs_french_age": crowspairs.CrowsPairsFrenchAge,
# # "crows_pairs_french_religion": crowspairs.CrowsPairsFrenchReligion,
# # "crows_pairs_french_disability": crowspairs.CrowsPairsFrenchDisability,
# # "crows_pairs_french_sexual_orientation": crowspairs.CrowsPairsFrenchSexualOrientation,
# # "crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
# # "crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
# # "crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
# # Requires manual download of data.
# # "storycloze_2016": storycloze.StoryCloze2016,
# # "storycloze_2018": storycloze.StoryCloze2018,
# # "sat": sat.SATAnalogies,
# }
# ALL_TASKS = sorted(list(TASK_REGISTRY))
# def get_task(task_name):
# try:
# return TASK_REGISTRY[task_name]
# except KeyError:
# print("Available tasks:")
# pprint(TASK_REGISTRY)
# raise KeyError(f"Missing task {task_name}")
# def get_task_name_from_object(task_object):
# for name, class_ in TASK_REGISTRY.items():
# if class_ is task_object:
# return name
# # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
# return (
# task_object.EVAL_HARNESS_NAME
# if hasattr(task_object, "EVAL_HARNESS_NAME")
# else type(task_object).__name__
# )
# def get_task_name_from_config(task_config):
# return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
# def get_task_dict(task_name_list: List[Union[str, dict, api.task.Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
# task_name_dict = {
# task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
# for task_name in task_name_list
# if isinstance(task_name, str)
# }
# task_name_from_config_dict = {
# get_task_name_from_config(task_config): api.task.ConfigurableTask(
# config=task_config
# )
# for task_config in task_name_list
# if isinstance(task_config, dict)
# }
# task_name_from_object_dict = {
# get_task_name_from_object(task_object): task_object
# for task_object in task_name_list
# if isinstance(task_object, api.task.Task)
# }
# assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
# return {
# **task_name_dict,
# **task_name_from_config_dict,
# **task_name_from_object_dict,
# }
import os
from typing import List, Union

from .arc import *

from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.register import (
    register_task,
    register_group,
    task_registry,
    group_registry,
)


def get_task_name_from_config(task_config):
    return "configurable_{dataset_path}_{dataset_name}".format(**task_config)


task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
for root, subdirs, file_list in os.walk(task_dir):
    if (subdirs == []) and (len(file_list) > 0):
        for file in file_list:
            if "yaml" in file:
                yaml_path = os.path.join(root, file)
                try:
                    config = utils.load_yaml_config(yaml_path)
                    SubClass = type(
                        config["task"] + "ConfigurableTask",
                        (ConfigurableTask,),
                        {"CONFIG": TaskConfig(**config)},
                    )

                    if "task" in config:
                        task_name = "{}:{}".format(
                            get_task_name_from_config(config), config["task"]
                        )
                        register_task(task_name)(SubClass)

                    if "group" in config:
                        for group in config["group"]:
                            register_group(group)(SubClass)
                except Exception as err:
                    print(f"Unexpected {err=}, {type(err)=}")

TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))


def get_task(task_name, config):
    try:
        return TASK_REGISTRY[task_name](config=config)
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(ALL_TASKS)
        raise KeyError(f"Missing task {task_name}")


def get_task_name_from_object(task_object):
    for name, class_ in TASK_REGISTRY.items():
        if class_ is task_object:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}
    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name not in task_name_from_registry_dict:
                        task_name_from_registry_dict = {
                            **task_name_from_registry_dict,
                            task_name: get_task(task_name=task_name, config=config),
                        }
            else:
                task_name = task_element
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict = {
                        **task_name_from_registry_dict,
                        task_name: get_task(task_name=task_element, config=config),
                    }

        elif isinstance(task_element, dict):
            task_element.update(config)
            task_name_from_config_dict = {
                **task_name_from_config_dict,
                get_task_name_from_config(task_element): ConfigurableTask(
                    config=task_element
                ),
            }

        elif isinstance(task_element, Task):
            task_name_from_object_dict = {
                **task_name_from_object_dict,
                get_task_name_from_object(task_element): task_element,
            }

    # task_name_from_registry_dict = {
    #     task_name: get_task(
    #         task_name=task_name,
    #         task_config=config
    #     )
    #     for group_name in task_name_list for task_name in GROUP_REGISTRY[group_name]
    #     if (isinstance(group_name, str)) and (group_name in GROUP_REGISTRY)
    # }

    # task_name_from_config_dict = {
    #     get_task_name_from_config(task_config): ConfigurableTask(
    #         config=task_config
    #     )
    #     for task_config in task_name_list
    #     if isinstance(task_config, dict)
    # }

    # # TODO: Do we still need this?
    # task_name_from_object_dict = {
    #     get_task_name_from_object(task_object): task_object
    #     for task_object in task_name_list
    #     if isinstance(task_object, Task)
    # }

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )
    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }
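A rough sketch of intended usage for the registry built above. It assumes the `@register_task`/`@register_group` decorators in `tasks/arc.py` have run (they do once this package is imported via `from .arc import *`) and that the task constructors accept the forwarded config kwargs.

```python
from lm_eval import tasks

# "arc" is a group name: it expands to every task registered under that group
# (arc_easy and arc_challenge via the decorators in tasks/arc.py), while plain
# task names are looked up in TASK_REGISTRY directly.
task_dict = tasks.get_task_dict(["arc"], num_fewshot=0)  # num_fewshot lands in each task's config

for name, task in task_dict.items():
    print(name, type(task).__name__)
```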
......@@ -12,11 +12,11 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
Homepage: https://allenai.org/data/arc
"""
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.task import MultipleChoiceTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{Clark2018ThinkYH,
......@@ -28,6 +28,8 @@ _CITATION = """
}
"""
@register_group("arc")
@register_task("arc_easy")
class ARCEasy(MultipleChoiceTask):
VERSION = "2.0"
......@@ -80,6 +82,7 @@ class ARCEasy(MultipleChoiceTask):
return doc["query"]
@register_group("arc")
@register_task("arc_challenge")
class ARCChallenge(ARCEasy):
DATASET_PATH = "ai2_arc"
......
# ARC
Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge
https://arxiv.org/pdf/1803.05457.pdf
The ARC dataset consists of 7,787 science exam questions drawn from a variety
of sources, including science questions provided under license by a research
partner affiliated with AI2. These are text-only, English language exam questions
that span several grade levels as indicated in the files. Each question has a
multiple choice structure (typically 4 answer options). The questions are sorted
into a Challenge Set of 2,590 “hard” questions (those that both a retrieval and
a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questions.
Homepage: https://allenai.org/data/arc
### Citation
```
@article{Clark2018ThinkYH,
title={Think you have Solved Question Answering? Try ARC, the AI2 Reasoning Challenge},
author={Peter Clark and Isaac Cowhey and Oren Etzioni and Tushar Khot and Ashish Sabharwal and Carissa Schoenick and Oyvind Tafjord},
journal={ArXiv},
year={2018},
volume={abs/1803.05457}
}
```
group:
  - arc_yaml
task: arc_challenge_yaml
dataset_path: ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
......@@ -6,11 +9,14 @@ validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (which dataset column is used, and how it maps to an index)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
  - metric: acc_mutual_info
    aggregation: mean
    higher_is_better: true
group:
  - arc_yaml
task: arc_easy_yaml
dataset_path: ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (which dataset column is used, and how it maps to an index)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
  - metric: acc_mutual_info
    aggregation: mean
    higher_is_better: true
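To make the Jinja fields in the two ARC configs above concrete, here is a minimal stand-alone rendering of a single document. The sample doc is hand-written to mirror the `ai2_arc` schema (question, choices.text, choices.label, answerKey), and plain `jinja2` is used rather than the harness's own template handling.

```python
from jinja2 import Environment

env = Environment()

doc = {
    "question": "Which gas do plants absorb from the air for photosynthesis?",
    "choices": {
        "text": ["oxygen", "carbon dioxide", "nitrogen", "hydrogen"],
        "label": ["A", "B", "C", "D"],
    },
    "answerKey": "B",
}

# doc_to_text -> the prompt string shown to the model
prompt = env.from_string("Question: {{question}}\nAnswer:").render(**doc)

# template_aliases -> the candidate answers and the gold index, written out in Python
answer_choices = doc["choices"]["text"]
gold = doc["choices"]["label"].index(doc["answerKey"])  # doc_to_target casts this to an int

print(prompt)                # Question: Which gas ...\nAnswer:
print(answer_choices[gold])  # "carbon dioxide"
```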
......@@ -17,13 +17,14 @@ model's sample/generation function.
Homepage: https://github.com/openai/grade-school-math
"""
import re
from lm_eval.api.task import Task
from lm_eval.api.metrics import mean
from lm_eval.api.instance import Instance
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{cobbe2021training,
......@@ -88,7 +89,13 @@ class GradeSchoolMath8K(Task):
"""
# NOTE: The paper implements "verifiers" that assign a score to multiple
# solutions and output the highest ranked solution.
return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, ["\n"]), idx=0, **kwargs)
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, ["\n"]),
idx=0,
**kwargs
)
# completion = rf.greedy_until(ctx, ["\n"])
# return completion
......
# "Training Verifiers to Solve Math Word Problems"
# https://arxiv.org/abs/2110.14168
# State-of-the-art language models can match human performance on many tasks, but
# they still struggle to robustly perform multi-step mathematical reasoning. To
# diagnose the failures of current models and support research, we introduce GSM8K,
# a dataset of 8.5K high quality linguistically diverse grade school math word problems.
# We find that even the largest transformer models fail to achieve high test performance,
# despite the conceptual simplicity of this problem distribution.
# NOTE: See the official implementation of the task:
# https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py
# for how to make use of the dataset's calculator annotations in your language
# model's sample/generation function.
# Homepage: https://github.com/openai/grade-school-math
# _CITATION = """
# @misc{cobbe2021training,
# title={Training Verifiers to Solve Math Word Problems},
# author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman},
# year={2021},
# eprint={2110.14168},
# archivePrefix={arXiv},
# primaryClass={cs.LG}
# }
# """
task: gsm8k_yaml
dataset_path: gsm8k
dataset_name: main
training_split: train
test_split: test
use_prompt: "qa-basic:question-newline-answer"
doc_to_target: "{{answer.split('### ')[-1]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
delimiter: "\n"
repeats: 4
# filter_list:
#   - name: "get-answer"
#     filter:
#       - function: "regex"
#         regex_pattern: "#### (\-?[0-9\.\,]+)"
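The commented-out `filter_list` hints at how generations would be post-processed before `exact_match` scoring. Below is a rough, hand-rolled illustration of that regex step, using plain `re` rather than the harness's filter classes; the completion text is invented.

```python
import re

ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

completion = (
    "Natalia sold 48 clips in April and half as many in May.\n"
    "48 / 2 = 24 and 48 + 24 = 72\n"
    "#### 72"
)

match = ANSWER_RE.search(completion)
extracted = match.group(1) if match else "[invalid]"
print(extracted)  # 72
```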
......@@ -12,10 +12,11 @@ in the broader discourse.
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity
from lm_eval.api.register import register_task, register_group
_CITATION = """
@misc{
......@@ -59,11 +60,18 @@ class LambadaBase(Task):
return " " + doc["text"].rsplit(" ", 1)[1]
def construct_requests(self, doc, ctx, **kwargs):
return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, self.doc_to_target(doc)), **kwargs)
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(ctx, self.doc_to_target(doc)),
**kwargs
)
def process_results(self, doc, results):
# TODO: this ^ is a hack. filters should make it so that we only have one response per request that we score
results = results[0] # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
results = results[
0
] # TODO: recheck this. currently a list of [(ll, is_greedy)] is passed in
ll, is_greedy = results
return {"ppl": ll, "acc": int(is_greedy)}
......@@ -91,6 +99,7 @@ class LambadaStandard(LambadaBase):
    def has_test_docs(self):
        return True


@register_task("lambada_openai")
class LambadaOpenAI(LambadaBase):
    """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
......
task: lambada_openai_yaml
dataset_path: EleutherAI/lambada_openai
dataset_name: default
output_type: loglikelihood
......
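For intuition, a back-of-the-envelope sketch of how the per-document `(loglikelihood, is_greedy)` pairs produced by `process_results` above could aggregate into the `ppl` and `acc` numbers. The numbers are invented, and the perplexity formula (exp of the negative mean loglikelihood) is written out by hand here rather than taken from `lm_eval.api.metrics`.

```python
import math

# one (loglikelihood of the target word, was it the greedy continuation?) pair per doc
per_doc = [(-1.2, True), (-4.7, False), (-0.8, True)]

lls = [ll for ll, _ in per_doc]
hits = [int(is_greedy) for _, is_greedy in per_doc]

ppl = math.exp(-sum(lls) / len(lls))  # lower is better
acc = sum(hits) / len(hits)           # fraction of exact final-word matches

print(f"ppl={ppl:.2f}  acc={acc:.2f}")
```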
......@@ -10,8 +10,9 @@ math, computer science, and philosophy papers.
Homepage: https://pile.eleuther.ai/
"""
from lm_eval.api.task import PerplexityTask
from lm_eval.api.register import register_task, register_group
_CITATION = """
@article{pile,
......@@ -34,7 +35,7 @@ class PilePerplexityTask(PerplexityTask):
    def test_docs(self):
        for doc in self.dataset["train"].select(range(100)):
            yield doc

    def has_validation_docs(self):
        return False
......@@ -139,4 +140,4 @@ class PileWikipedia(PilePerplexityTask):
class PileYoutubeSubtitles(PilePerplexityTask):
    DATASET_NAME = "pile_youtubesubtitles"
# The Pile: An 800GB Dataset of Diverse Text for Language Modeling
# https://arxiv.org/pdf/2101.00027.pdf
# The Pile is a 825 GiB diverse, open source language modelling data set that consists
# of 22 smaller, high-quality datasets combined together. To score well on Pile
# BPB (bits per byte), a model must be able to understand many disparate domains
# including books, github repositories, webpages, chat logs, and medical, physics,
# math, computer science, and philosophy papers.
# Homepage: https://pile.eleuther.ai/
# _CITATION = """
# @article{pile,
# title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
# author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
# journal={arXiv preprint arXiv:2101.00027},
# year={2020}
# }
# """
names:
  - pile_enron_yaml
dataset_path: EleutherAI/the_pile
dataset_name: enron_emails
output_type: loglikelihood_rolling
......@@ -16,4 +37,4 @@ metric_list:
    higher_is_better: false
  - metric: bits_per_byte
    aggregation: bits_per_byte
    higher_is_better: false
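As a rough reading of the `bits_per_byte` metric above for a `loglikelihood_rolling` task: normalize a document's total loglikelihood by its length and convert nats to bits. This is a hand-written sketch with invented numbers; the exact normalization used by the harness is an assumption here, not taken from its implementation.

```python
import math

# (total loglikelihood of the document in nats, document length in UTF-8 bytes)
per_doc = [(-2104.3, 1850), (-980.7, 912)]

total_ll = sum(ll for ll, _ in per_doc)
total_bytes = sum(n for _, n in per_doc)

bits_per_byte = -total_ll / (total_bytes * math.log(2))  # lower is better
byte_perplexity = math.exp(-total_ll / total_bytes)

print(f"bits_per_byte={bits_per_byte:.3f}  byte_perplexity={byte_perplexity:.3f}")
```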
group:
  - super-glue-promptsource
task: "GPT-3 Style"
dataset_path: super_glue
dataset_name: boolq
training_split: train
validation_split: validation
test_split: test
use_prompt: "promptsource:GPT-3 Style"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
delimiter: "\n"
# filters: [
#   ["regex", ["regex", "take_first"]]
# ]
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "based on the following passage"
use_prompt: "promptsource:based on the following passage"
group:
  - super-glue-promptsource
task: "GPT-3 style"
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
doc_to_text: "Suppose {{premise}} Can we infer that \"{{hypothesis}}\"? Yes, no, or maybe?"
doc_to_target: "{% set answer_choices = ['Yes', 'No', 'Maybe'] %}{{answer_choices[label]}}"
use_prompt: "promptsource:GPT-3 style"
metric_list:
  - metric: exact_match
    aggregation: mean
......
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "MNLI crowdsource"
use_prompt: "promptsource:MNLI crowdsource"
include: promptsource-00.yaml
group:
  - super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
group:
  - super-glue-t5-prompt
task: t5-prompt
reference: "From Raffel et al. 2019"
dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}"
doc_to_target: "{% set answer_choices = ['entailment', 'contradiction', 'neutral'] %}{{answer_choices[label]}}"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
group:
  - super-glue-promptsource
task: "C1 or C2? premise, so/because…"
dataset_path: super_glue
dataset_name: copa
training_split: train
validation_split: validation
use_prompt: "promptsource:C1 or C2? premise, so/because…"
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true