Commit 53a4817e authored by lintangsutawika's avatar lintangsutawika
Browse files

Merge branch 'big-refactor' of...

Merge branch 'big-refactor' of https://github.com/EleutherAI/lm-evaluation-harness into p3-prompt-task
parents 41d16c2b f6b76f5d
...@@ -25,6 +25,7 @@ HIGHER_IS_BETTER_REGISTRY = { ...@@ -25,6 +25,7 @@ HIGHER_IS_BETTER_REGISTRY = {
"acc": True, "acc": True,
"acc_norm": True, "acc_norm": True,
"acc_mutual_info": True,
"word_perplexity": False, "word_perplexity": False,
"byte_perplexity": False, "byte_perplexity": False,
"bits_per_byte": False, "bits_per_byte": False,
......
...@@ -13,6 +13,7 @@ AGGREGATION_REGISTRY = {} ...@@ -13,6 +13,7 @@ AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = { METRIC_REGISTRY = {
"acc": None, "acc": None,
"acc_norm": None, "acc_norm": None,
"acc_mutual_info": None,
"word_perplexity": None, "word_perplexity": None,
"byte_perplexity": None, "byte_perplexity": None,
} }
......
...@@ -2,6 +2,7 @@ import abc ...@@ -2,6 +2,7 @@ import abc
from dataclasses import dataclass from dataclasses import dataclass
import re import re
import ast
import evaluate import evaluate
import random import random
import itertools import itertools
...@@ -26,7 +27,8 @@ from lm_eval.api import samplers ...@@ -26,7 +27,8 @@ from lm_eval.api import samplers
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
task_name: str = None names: str = None
task_name: str = None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
dataset_path: str = None dataset_path: str = None
dataset_name: str = None dataset_name: str = None
training_split: str = None training_split: str = None
...@@ -53,6 +55,8 @@ class TaskConfig(dict): ...@@ -53,6 +55,8 @@ class TaskConfig(dict):
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
use_prompt: str = None use_prompt: str = None
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self): def __post_init__(self):
# allow user-specified aliases so that users can # allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of # force prompt-compatibility for some prompt regardless of
...@@ -60,6 +64,10 @@ class TaskConfig(dict): ...@@ -60,6 +64,10 @@ class TaskConfig(dict):
self.doc_to_text = self.template_aliases + self.doc_to_text self.doc_to_text = self.template_aliases + self.doc_to_text
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
self.task_name = self.names[0]
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
...@@ -267,7 +275,7 @@ class Task(abc.ABC): ...@@ -267,7 +275,7 @@ class Task(abc.ABC):
) )
# TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute # TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 2)) inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 1))
if not isinstance(inst, list): if not isinstance(inst, list):
inst = [inst] inst = [inst]
...@@ -404,12 +412,18 @@ class ConfigurableTask(Task): ...@@ -404,12 +412,18 @@ class ConfigurableTask(Task):
VERSION = "2.0" VERSION = "2.0"
OUTPUT_TYPE = None OUTPUT_TYPE = None
CONFIG = None
def __init__( def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
): ):
# if we are a subclass that has the CONFIG class attr set, ignore whatever is passed.
self._config = TaskConfig(**config) self._config = self.CONFIG
# else, if a config was passed as kwarg: use it
if (self._config is None) and config:
self._config = TaskConfig(**config)
if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
if self._config.output_type is not None: if self._config.output_type is not None:
self.OUTPUT_TYPE = self._config.output_type self.OUTPUT_TYPE = self._config.output_type
...@@ -534,8 +548,10 @@ class ConfigurableTask(Task): ...@@ -534,8 +548,10 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),) arguments=(self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
import ast # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
return [ # TODO: any cleaner way to do this?
choices = ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc))
request_list = [
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
...@@ -543,9 +559,30 @@ class ConfigurableTask(Task): ...@@ -543,9 +559,30 @@ class ConfigurableTask(Task):
idx=i, idx=i,
**kwargs, **kwargs,
) )
for i, choice in enumerate(ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc))) for i, choice in enumerate(choices)
# we pass the user-defined answer_choices var (in aliases) and echo the result. TODO: any cleaner way to do this?
] ]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
arguments=(ctx, self._config.delimiter) arguments=(ctx, self._config.delimiter)
...@@ -574,21 +611,40 @@ class ConfigurableTask(Task): ...@@ -574,21 +611,40 @@ class ConfigurableTask(Task):
"bits_per_byte": (loglikelihood, bytes_), "bits_per_byte": (loglikelihood, bytes_),
} }
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls = [res[0] for res in results] # only retain loglikelihoods, discard is_greedy TODO: keep is_greedy to report exact_match as well on multiple choice probs lls = [res[0] for res in results] # only retain loglikelihoods, discard is_greedy
gold = int(self.doc_to_target(doc)) gold = int(self.doc_to_target(doc))
# TODO: remove dependence on "gold" and "choices" columns # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc))
if 2 * len(choices) == len(lls) and "acc_mutual_info" in self._metric_list.keys():
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
acc = 1.0 if np.argmax(lls) == gold else 0.0 acc = 1.0 if np.argmax(lls) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]]) completion_len = np.array([float(len(i)) for i in choices])
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
# TODO: set which normalization metrics should be reported, and calculate them
# TODO: add mutual info.
result_dict = { result_dict = {
"acc": acc, "acc": acc,
"acc_norm": acc_norm, "acc_norm": acc_norm,
} }
# TODO: set which normalization metrics should be reported, and calculate them
if "exact_match" in self._metric_list.keys():
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
is_greedy = [res[1] for res in results] # take only the `is_greedy` results
is_greedy = is_greedy[gold] # take value for the gold answer
result_dict["exact_match"] = int(is_greedy)
if "acc_mutual_info" in self._metric_list.keys():
lls_mutual_info = [ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)]
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None: if self._config.gold_alias is not None:
...@@ -626,7 +682,7 @@ class MultipleChoiceTask(Task): ...@@ -626,7 +682,7 @@ class MultipleChoiceTask(Task):
return " " + doc["choices"][doc["gold"]] return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
# TODO: add mutual info here?
return [Instance( return [Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
...@@ -705,8 +761,6 @@ class PerplexityTask(Task, abc.ABC): ...@@ -705,8 +761,6 @@ class PerplexityTask(Task, abc.ABC):
assert not ctx assert not ctx
return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), idx=0, **kwargs) return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), idx=0, **kwargs)
# req = rf.loglikelihood_rolling(self.doc_to_target(doc))
# return req
def process_results(self, doc, results): def process_results(self, doc, results):
(loglikelihood,) = results (loglikelihood,) = results
...@@ -761,6 +815,38 @@ def register_task(*names): ...@@ -761,6 +815,38 @@ def register_task(*names):
return decorate return decorate
def register_yaml_task(yaml_path):
# same goal as register_task() but used to register yamls
import yaml
with open(yaml_path, "r") as f:
config = yaml.load(f, yaml.Loader)
from functools import partial
# TODO: strip whitespace from name?
# TODO: ensure num_fewshot overrides the config vals
def decorate(names, cls):
for name in names:
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import properly.
return cls
# we create a subclass that has subclass attr CONFIG = our yaml config, and decorate with the config's specified aliases
names = config['names']
yaml_task = decorate(
names,
type(config['names'][0] + 'ConfigurableTask', (ConfigurableTask,), {'CONFIG': TaskConfig(**config)})
)
##### Task registry utils and setup. ##### Task registry utils and setup.
# ALL_TASKS = sorted(list(TASK_REGISTRY)) # ALL_TASKS = sorted(list(TASK_REGISTRY))
......
...@@ -6,7 +6,7 @@ import lm_eval.api.metrics ...@@ -6,7 +6,7 @@ import lm_eval.api.metrics
import lm_eval.models import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.api import lm_eval.api
from lm_eval.utils import positional_deprecated, run_task_tests, make_table from lm_eval.utils import positional_deprecated, run_task_tests, make_table, get_git_commit_hash
@positional_deprecated @positional_deprecated
...@@ -90,6 +90,7 @@ def simple_evaluate( ...@@ -90,6 +90,7 @@ def simple_evaluate(
"limit": limit, "limit": limit,
"bootstrap_iters": bootstrap_iters, "bootstrap_iters": bootstrap_iters,
} }
results["git_hash"] = get_git_commit_hash()
return results return results
......
from lm_eval.api.model import LM, MODEL_REGISTRY
from . import gpt2 from . import gpt2
from . import gpt3 from . import gpt3
from . import textsynth from . import textsynth
......
from pprint import pprint import os
from typing import List, Union
import sacrebleu from lm_eval.api.task import register_yaml_task
from lm_eval import api
# from . import superglue from .vanilla import *
# from . import glue
from . import arc
# from . import coqa
# from . import race
# from . import webqs
# from . import anli
# from . import wsc273
# from . import winogrande
# from . import quac
# from . import hellaswag
# from . import swag
# from . import openbookqa
# from . import squad
# from . import naturalqs
# from . import sat
# from . import arithmetic
from . import lambada
# from . import piqa
# from . import prost
# from . import mc_taco
# from . import triviaqa
# from . import pubmedqa
# from . import sciq
# from . import qasper
# from . import qa4mre
# from . import translation
# from . import headqa
# from . import mathqa
# from . import hendrycks_ethics
# from . import drop
# from . import unscramble
# from . import logiqa
# from . import hendrycks_test
# from . import hendrycks_math
# from . import cbt
# from . import lambada_cloze
from . import pile
from . import wikitext
# from . import lambada_multilingual
# from . import mutual
# from . import truthfulqa
# from . import blimp
# from . import asdiv
from . import gsm8k
# from . import storycloze
# from . import toxigen
# from . import crowspairs
# ######################################## # we want to register all yaml tasks in our .yaml folder.
# # Translation tasks yaml_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + "yaml"
# ########################################
# # 6 total
# gpt3_translation_benchmarks = {
# "wmt14": ["en-fr", "fr-en"], # French
# "wmt16": ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
# }
for yaml in sorted(os.listdir(yaml_dir)):
# # 28 total yaml = os.path.join(yaml_dir, yaml)
# selected_translation_benchmarks = { register_yaml_task(yaml)
# **gpt3_translation_benchmarks,
# "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
# "iwslt17": ["en-ar", "ar-en"], # Arabic
# }
# # 319 total
# all_translation_benchmarks = {
# ts: sacrebleu.get_langpairs_for_testset(ts)
# for ts in sacrebleu.get_available_testsets()
# }
# ########################################
# # All tasks
# ########################################
# TASK_REGISTRY = {
# # GLUE
# # "cola": glue.CoLA,
# # "mnli": glue.MNLI,
# # "mnli_mismatched": glue.MNLIMismatched,
# # "mrpc": glue.MRPC,
# # "rte": glue.RTE,
# # "qnli": glue.QNLI,
# # "qqp": glue.QQP,
# # # "stsb": glue.STSB, # not implemented yet
# # "sst": glue.SST,
# # "wnli": glue.WNLI,
# # # SuperGLUE
# # "boolq": superglue.BoolQ,
# # "cb": superglue.CommitmentBank,
# # "copa": superglue.Copa,
# # "multirc": superglue.MultiRC,
# # "record": superglue.ReCoRD,
# # "wic": superglue.WordsInContext,
# # "wsc": superglue.SGWinogradSchemaChallenge,
# # # Order by benchmark/genre?
# # "coqa": coqa.CoQA,
# # "drop": drop.DROP,
# "lambada_openai": lambada.LambadaOpenAI,
# "lambada_standard": lambada.LambadaStandard,
# # "lambada_openai_cloze": lambada_cloze.LambadaOpenAICloze,
# # "lambada_standard_cloze": lambada_cloze.LambadaStandardCloze,
# # # multilingual lambada
# # **lambada_multilingual.construct_tasks(),
# "wikitext": wikitext.WikiText,
# # # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
# # # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
# # "piqa": piqa.PiQA,
# # "prost": prost.PROST,
# # "mc_taco": mc_taco.MCTACO,
# # # Science related
# # "pubmedqa": pubmedqa.Pubmed_QA,
# # "sciq": sciq.SciQ,
# # "qasper": qasper.QASPER,
# # "qa4mre_2011": qa4mre.QA4MRE_2011,
# # "qa4mre_2012": qa4mre.QA4MRE_2012,
# # "qa4mre_2013": qa4mre.QA4MRE_2013,
# # "triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy,
# "arc_challenge": arc.ARCChallenge,
# # # "quac": quac.QuAC, # not implemented yet
# # "logiqa": logiqa.LogiQA,
# # "hellaswag": hellaswag.HellaSwag,
# # "swag": swag.SWAG,
# # "openbookqa": openbookqa.OpenBookQA,
# # "squad2": squad.SQuAD2,
# # "race": race.RACE,
# # # "naturalqs": naturalqs.NaturalQs, # not implemented yet
# # "headqa": headqa.HeadQAEsDeprecated, # for backwards compat - headqa used to default to es
# # "headqa_es": headqa.HeadQAEs,
# # "headqa_en": headqa.HeadQAEn,
# # "mathqa": mathqa.MathQA,
# # "webqs": webqs.WebQs,
# # "wsc273": wsc273.WinogradSchemaChallenge273,
# # "winogrande": winogrande.Winogrande,
# # "anli_r1": anli.ANLIRound1,
# # "anli_r2": anli.ANLIRound2,
# # "anli_r3": anli.ANLIRound3,
# # "ethics_cm": hendrycks_ethics.EthicsCM,
# # "ethics_deontology": hendrycks_ethics.EthicsDeontology,
# # "ethics_justice": hendrycks_ethics.EthicsJustice,
# # "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
# # "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
# # "ethics_virtue": hendrycks_ethics.EthicsVirtue,
# # "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
# # "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
# # # dialogue
# # "mutual": mutual.MuTual,
# # "mutual_plus": mutual.MuTualPlus,
# # # math
# # "math_algebra": hendrycks_math.MathAlgebra,
# # "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
# # "math_geometry": hendrycks_math.MathGeometry,
# # "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
# # "math_num_theory": hendrycks_math.MathNumberTheory,
# # "math_prealgebra": hendrycks_math.MathPrealgebra,
# # "math_precalc": hendrycks_math.MathPrecalculus,
# # "math_asdiv": asdiv.Asdiv,
# "gsm8k": gsm8k.GradeSchoolMath8K,
# # # arithmetic
# # "arithmetic_2da": arithmetic.Arithmetic2DPlus,
# # "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
# # "arithmetic_3da": arithmetic.Arithmetic3DPlus,
# # "arithmetic_3ds": arithmetic.Arithmetic3DMinus,
# # "arithmetic_4da": arithmetic.Arithmetic4DPlus,
# # "arithmetic_4ds": arithmetic.Arithmetic4DMinus,
# # "arithmetic_5da": arithmetic.Arithmetic5DPlus,
# # "arithmetic_5ds": arithmetic.Arithmetic5DMinus,
# # "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
# # "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# # # TODO Perhaps make these groups of tasks
# # # e.g. anli, arithmetic, openai_translations, harness_translations
# # # hendrycksTest (57 tasks)
# # **hendrycks_test.create_all_tasks(),
# # # e.g. wmt14-fr-en
# # **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# # # chef's selection, mostly wmt20
# # **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
# # # Word Scrambling and Manipulation Tasks
# # "anagrams1": unscramble.Anagrams1,
# # "anagrams2": unscramble.Anagrams2,
# # "cycle_letters": unscramble.CycleLetters,
# # "random_insertion": unscramble.RandomInsertion,
# # "reversed_words": unscramble.ReversedWords,
# # # Pile
# # "pile_arxiv": pile.PileArxiv,
# # "pile_books3": pile.PileBooks3,
# # "pile_bookcorpus2": pile.PileBookCorpus2,
# # "pile_dm-mathematics": pile.PileDmMathematics,
# # "pile_enron": pile.PileEnron,
# # "pile_europarl": pile.PileEuroparl,
# # "pile_freelaw": pile.PileFreeLaw,
# # "pile_github": pile.PileGithub,
# # "pile_gutenberg": pile.PileGutenberg,
# # "pile_hackernews": pile.PileHackernews,
# # "pile_nih-exporter": pile.PileNIHExporter,
# # "pile_opensubtitles": pile.PileOpenSubtitles,
# # "pile_openwebtext2": pile.PileOpenWebText2,
# # "pile_philpapers": pile.PilePhilPapers,
# # "pile_pile-cc": pile.PilePileCc,
# # "pile_pubmed-abstracts": pile.PilePubmedAbstracts,
# # "pile_pubmed-central": pile.PilePubmedCentral,
# # "pile_stackexchange": pile.PileStackExchange,
# # "pile_uspto": pile.PileUspto,
# # "pile_ubuntu-irc": pile.PileUbuntuIrc,
# # "pile_wikipedia": pile.PileWikipedia,
# # "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
# # # BLiMP
# # "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
# # "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
# # "blimp_anaphor_number_agreement": blimp.BlimpAnaphorNumberAgreement,
# # "blimp_animate_subject_passive": blimp.BlimpAnimateSubjectPassive,
# # "blimp_animate_subject_trans": blimp.BlimpAnimateSubjectTrans,
# # "blimp_causative": blimp.BlimpCausative,
# # "blimp_complex_NP_island": blimp.BlimpComplex_NPIsland,
# # "blimp_coordinate_structure_constraint_complex_left_branch": blimp.BlimpCoordinateStructureConstraintComplexLeftBranch,
# # "blimp_coordinate_structure_constraint_object_extraction": blimp.BlimpCoordinateStructureConstraintObjectExtraction,
# # "blimp_determiner_noun_agreement_1": blimp.BlimpDeterminerNounAgreement_1,
# # "blimp_determiner_noun_agreement_2": blimp.BlimpDeterminerNounAgreement_2,
# # "blimp_determiner_noun_agreement_irregular_1": blimp.BlimpDeterminerNounAgreementIrregular_1,
# # "blimp_determiner_noun_agreement_irregular_2": blimp.BlimpDeterminerNounAgreementIrregular_2,
# # "blimp_determiner_noun_agreement_with_adj_2": blimp.BlimpDeterminerNounAgreementWithAdj_2,
# # "blimp_determiner_noun_agreement_with_adj_irregular_1": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_1,
# # "blimp_determiner_noun_agreement_with_adj_irregular_2": blimp.BlimpDeterminerNounAgreementWithAdjIrregular_2,
# # "blimp_determiner_noun_agreement_with_adjective_1": blimp.BlimpDeterminerNounAgreementWithAdjective_1,
# # "blimp_distractor_agreement_relational_noun": blimp.BlimpDistractorAgreementRelationalNoun,
# # "blimp_distractor_agreement_relative_clause": blimp.BlimpDistractorAgreementRelativeClause,
# # "blimp_drop_argument": blimp.BlimpDropArgument,
# # "blimp_ellipsis_n_bar_1": blimp.BlimpEllipsisNBar_1,
# # "blimp_ellipsis_n_bar_2": blimp.BlimpEllipsisNBar_2,
# # "blimp_existential_there_object_raising": blimp.BlimpExistentialThereObjectRaising,
# # "blimp_existential_there_quantifiers_1": blimp.BlimpExistentialThereQuantifiers_1,
# # "blimp_existential_there_quantifiers_2": blimp.BlimpExistentialThereQuantifiers_2,
# # "blimp_existential_there_subject_raising": blimp.BlimpExistentialThereSubjectRaising,
# # "blimp_expletive_it_object_raising": blimp.BlimpExpletiveItObjectRaising,
# # "blimp_inchoative": blimp.BlimpInchoative,
# # "blimp_intransitive": blimp.BlimpIntransitive,
# # "blimp_irregular_past_participle_adjectives": blimp.BlimpIrregularPastParticipleAdjectives,
# # "blimp_irregular_past_participle_verbs": blimp.BlimpIrregularPastParticipleVerbs,
# # "blimp_irregular_plural_subject_verb_agreement_1": blimp.BlimpIrregularPluralSubjectVerbAgreement_1,
# # "blimp_irregular_plural_subject_verb_agreement_2": blimp.BlimpIrregularPluralSubjectVerbAgreement_2,
# # "blimp_left_branch_island_echo_question": blimp.BlimpLeftBranchIslandEchoQuestion,
# # "blimp_left_branch_island_simple_question": blimp.BlimpLeftBranchIslandSimpleQuestion,
# # "blimp_matrix_question_npi_licensor_present": blimp.BlimpMatrixQuestionNpiLicensorPresent,
# # "blimp_npi_present_1": blimp.BlimpNpiPresent_1,
# # "blimp_npi_present_2": blimp.BlimpNpiPresent_2,
# # "blimp_only_npi_licensor_present": blimp.BlimpOnlyNpiLicensorPresent,
# # "blimp_only_npi_scope": blimp.BlimpOnlyNpiScope,
# # "blimp_passive_1": blimp.BlimpPassive_1,
# # "blimp_passive_2": blimp.BlimpPassive_2,
# # "blimp_principle_A_c_command": blimp.BlimpPrinciple_ACCommand,
# # "blimp_principle_A_case_1": blimp.BlimpPrinciple_ACase_1,
# # "blimp_principle_A_case_2": blimp.BlimpPrinciple_ACase_2,
# # "blimp_principle_A_domain_1": blimp.BlimpPrinciple_ADomain_1,
# # "blimp_principle_A_domain_2": blimp.BlimpPrinciple_ADomain_2,
# # "blimp_principle_A_domain_3": blimp.BlimpPrinciple_ADomain_3,
# # "blimp_principle_A_reconstruction": blimp.BlimpPrinciple_AReconstruction,
# # "blimp_regular_plural_subject_verb_agreement_1": blimp.BlimpRegularPluralSubjectVerbAgreement_1,
# # "blimp_regular_plural_subject_verb_agreement_2": blimp.BlimpRegularPluralSubjectVerbAgreement_2,
# # "blimp_sentential_negation_npi_licensor_present": blimp.BlimpSententialNegationNpiLicensorPresent,
# # "blimp_sentential_negation_npi_scope": blimp.BlimpSententialNegationNpiScope,
# # "blimp_sentential_subject_island": blimp.BlimpSententialSubjectIsland,
# # "blimp_superlative_quantifiers_1": blimp.BlimpSuperlativeQuantifiers_1,
# # "blimp_superlative_quantifiers_2": blimp.BlimpSuperlativeQuantifiers_2,
# # "blimp_tough_vs_raising_1": blimp.BlimpToughVsRaising_1,
# # "blimp_tough_vs_raising_2": blimp.BlimpToughVsRaising_2,
# # "blimp_transitive": blimp.BlimpTransitive,
# # "blimp_wh_island": blimp.BlimpWhIsland,
# # "blimp_wh_questions_object_gap": blimp.BlimpWhQuestionsObjectGap,
# # "blimp_wh_questions_subject_gap": blimp.BlimpWhQuestionsSubjectGap,
# # "blimp_wh_questions_subject_gap_long_distance": blimp.BlimpWhQuestionsSubjectGapLongDistance,
# # "blimp_wh_vs_that_no_gap": blimp.BlimpWhVsThatNoGap,
# # "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
# # "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
# # "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
# # "toxigen": toxigen.ToxiGen,
# # "crows_pairs_english": crowspairs.CrowsPairsEnglish,
# # "crows_pairs_english_race_color": crowspairs.CrowsPairsEnglishRaceColor,
# # "crows_pairs_english_socioeconomic": crowspairs.CrowsPairsEnglishSocioeconomic,
# # "crows_pairs_english_gender": crowspairs.CrowsPairsEnglishGender,
# # "crows_pairs_english_age": crowspairs.CrowsPairsEnglishAge,
# # "crows_pairs_english_religion": crowspairs.CrowsPairsEnglishReligion,
# # "crows_pairs_english_disability": crowspairs.CrowsPairsEnglishDisability,
# # "crows_pairs_english_sexual_orientation": crowspairs.CrowsPairsEnglishSexualOrientation,
# # "crows_pairs_english_nationality": crowspairs.CrowsPairsEnglishNationality,
# # "crows_pairs_english_physical_appearance": crowspairs.CrowsPairsEnglishPhysicalAppearance,
# # "crows_pairs_english_autre": crowspairs.CrowsPairsEnglishAutre,
# # "crows_pairs_french": crowspairs.CrowsPairsFrench,
# # "crows_pairs_french_race_color": crowspairs.CrowsPairsFrenchRaceColor,
# # "crows_pairs_french_socioeconomic": crowspairs.CrowsPairsFrenchSocioeconomic,
# # "crows_pairs_french_gender": crowspairs.CrowsPairsFrenchGender,
# # "crows_pairs_french_age": crowspairs.CrowsPairsFrenchAge,
# # "crows_pairs_french_religion": crowspairs.CrowsPairsFrenchReligion,
# # "crows_pairs_french_disability": crowspairs.CrowsPairsFrenchDisability,
# # "crows_pairs_french_sexual_orientation": crowspairs.CrowsPairsFrenchSexualOrientation,
# # "crows_pairs_french_nationality": crowspairs.CrowsPairsFrenchNationality,
# # "crows_pairs_french_physical_appearance": crowspairs.CrowsPairsFrenchPhysicalAppearance,
# # "crows_pairs_french_autre": crowspairs.CrowsPairsFrenchAutre,
# # Requires manual download of data.
# # "storycloze_2016": storycloze.StoryCloze2016,
# # "storycloze_2018": storycloze.StoryCloze2018,
# # "sat": sat.SATAnalogies,
# }
# ALL_TASKS = sorted(list(TASK_REGISTRY))
# def get_task(task_name):
# try:
# return TASK_REGISTRY[task_name]
# except KeyError:
# print("Available tasks:")
# pprint(TASK_REGISTRY)
# raise KeyError(f"Missing task {task_name}")
# def get_task_name_from_object(task_object):
# for name, class_ in TASK_REGISTRY.items():
# if class_ is task_object:
# return name
# # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
# return (
# task_object.EVAL_HARNESS_NAME
# if hasattr(task_object, "EVAL_HARNESS_NAME")
# else type(task_object).__name__
# )
# def get_task_name_from_config(task_config):
# return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
# def get_task_dict(task_name_list: List[Union[str, dict, api.task.Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
# task_name_dict = {
# task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
# for task_name in task_name_list
# if isinstance(task_name, str)
# }
# task_name_from_config_dict = {
# get_task_name_from_config(task_config): api.task.ConfigurableTask(
# config=task_config
# )
# for task_config in task_name_list
# if isinstance(task_config, dict)
# }
# task_name_from_object_dict = {
# get_task_name_from_object(task_object): task_object
# for task_object in task_name_list
# if isinstance(task_object, api.task.Task)
# }
# assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
# return {
# **task_name_dict,
# **task_name_from_config_dict,
# **task_name_from_object_dict,
# }
\ No newline at end of file
from . import arc
from . import gsm8k
from . import lambada
from . import pile
from . import wikitext
# TODO: define via __all__
\ No newline at end of file
names:
- arc_challenge_yaml
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Challenge dataset_name: ARC-Challenge
output_type: multiple_choice output_type: multiple_choice
...@@ -13,4 +15,10 @@ metric_list: ...@@ -13,4 +15,10 @@ metric_list:
higher_is_better: true higher_is_better: true
- metric: acc_norm - metric: acc_norm
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
\ No newline at end of file - metric: acc_mutual_info
aggregation: mean
higher_is_better: true
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
\ No newline at end of file
names:
- arc_easy_yaml
dataset_path: ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int.
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
- metric: acc_mutual_info
aggregation: mean
higher_is_better: true
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
\ No newline at end of file
names:
- gsm8k_yaml
dataset_path: gsm8k dataset_path: gsm8k
dataset_name: main dataset_name: main
training_split: train training_split: train
......
names:
- lambada_openai_yaml
dataset_path: EleutherAI/lambada_openai dataset_path: EleutherAI/lambada_openai
dataset_name: default dataset_name: default
output_type: loglikelihood output_type: loglikelihood
......
names:
- pile_enron_yaml
dataset_path: EleutherAI/the_pile dataset_path: EleutherAI/the_pile
dataset_name: enron_emails dataset_name: enron_emails
output_type: loglikelihood_rolling output_type: loglikelihood_rolling
......
names:
- sglue_cb_yamltest
dataset_path: super_glue dataset_path: super_glue
dataset_name: cb dataset_name: cb
training_split: train training_split: train
......
...@@ -240,6 +240,19 @@ def run_task_tests(task_list: List[str]): ...@@ -240,6 +240,19 @@ def run_task_tests(task_list: List[str]):
) )
def get_git_commit_hash():
"""
Gets the git commit hash of your current repo (if it exists).
Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
"""
try:
git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
git_hash = git_hash.decode()
except subprocess.CalledProcessError:
git_hash = None
return git_hash
env = Environment(loader=BaseLoader, undefined=StrictUndefined) env = Environment(loader=BaseLoader, undefined=StrictUndefined)
......
...@@ -3,15 +3,16 @@ import json ...@@ -3,15 +3,16 @@ import json
import logging import logging
import fnmatch import fnmatch
import yaml import yaml
import os
from lm_eval import tasks, evaluator from lm_eval import evaluator, tasks
# import lm_eval.api.task
from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY
logging.getLogger("openai").setLevel(logging.WARNING) logging.getLogger("openai").setLevel(logging.WARNING)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
ALL_TASKS = sorted(list(TASK_REGISTRY)) ALL_TASKS = sorted(list(TASK_REGISTRY))
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices):
self.choices = choices self.choices = choices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment