Commit f275301a authored by haileyschoelkopf, committed by Hailey Schoelkopf

make tasks and models registered by decorators

parent e7c18e53
@@ -2,6 +2,29 @@ import abc
 from lm_eval import utils
 
+MODEL_REGISTRY = {}
+
+
+def register_model(name):
+    # TODO: should fairseq/elk be cited for this design pattern?
+    def decorate(cls):
+        assert (
+            issubclass(cls, LM)
+        ), f"Model '{name}' ({cls.__name__}) must extend LM class"
+        assert (
+            name not in MODEL_REGISTRY
+        ), f"Model named '{name}' conflicts with existing model!"
+        MODEL_REGISTRY[name] = cls
+        return cls
+
+    return decorate
+
+
+def get_model(model_name):
+    return MODEL_REGISTRY[model_name]
+
+
 class LM(abc.ABC):
     def __init__(self):
...
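For context, a minimal usage sketch of the new decorator; the `MyDummyLM` class and the "my-dummy" key are hypothetical, not part of this commit:

# Hypothetical sketch: registering and looking up a model with the new decorator.
from lm_eval.api.model import LM, register_model, get_model

@register_model("my-dummy")      # stores the class in MODEL_REGISTRY under the key "my-dummy"
class MyDummyLM(LM):
    ...                          # LM's abstract interface would still need implementing

assert get_model("my-dummy") is MyDummyLM   # lookup returns the class itself, not an instance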
@@ -9,6 +9,8 @@ import itertools
 import datasets
 import numpy as np
+from typing import List, Union
 
 from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
@@ -31,7 +33,7 @@ class TaskConfig(dict):
     # TODO: add this as more jinja2 appended to start of jinja2 templates. Should allow users to set vars
     # s.t. they can define e.g. {% set question = query %} to map dataset columns to "canonical" names in prompts.
-    template_vars: str = None
+    template_aliases: str = None
     doc_to_text: str = None
     doc_to_target: str = None
@@ -609,3 +611,82 @@ class PerplexityTask(Task, abc.ABC):
     def count_words(cls, doc):
         """Downstream tasks with custom word boundaries should override this!"""
         return len(re.split(r"\s+", doc))
+
+
+# TODO: confirm we want this to go in this file
+TASK_REGISTRY = {}
+ALL_TASKS = []
+
+
+def register_task(name):
+    def decorate(cls):
+        assert (
+            issubclass(cls, Task)
+        ), f"Task '{name}' ({cls.__name__}) must extend Task class"
+        assert (
+            name not in TASK_REGISTRY
+        ), f"Task named '{name}' conflicts with existing task!"
+        TASK_REGISTRY[name] = cls
+        ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.
+        return cls
+
+    return decorate
+
+
+##### Task registry utils and setup.
+# ALL_TASKS = sorted(list(TASK_REGISTRY))
+
+
+def get_task(task_name):
+    try:
+        return TASK_REGISTRY[task_name]
+    except KeyError:
+        print("Available tasks:")
+        pprint(TASK_REGISTRY)
+        raise KeyError(f"Missing task {task_name}")
+
+
+def get_task_name_from_object(task_object):
+    for name, class_ in TASK_REGISTRY.items():
+        if class_ is task_object:
+            return name
+
+    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
+    return (
+        task_object.EVAL_HARNESS_NAME
+        if hasattr(task_object, "EVAL_HARNESS_NAME")
+        else type(task_object).__name__
+    )
+
+
+def get_task_name_from_config(task_config):
+    return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
+
+
+def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None):  # TODO: pass num_fewshot and other cmdline overrides in a better way
+    task_name_dict = {
+        task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
+        for task_name in task_name_list
+        if isinstance(task_name, str)
+    }
+    task_name_from_config_dict = {
+        get_task_name_from_config(task_config): ConfigurableTask(
+            config=task_config
+        )
+        for task_config in task_name_list
+        if isinstance(task_config, dict)
+    }
+    task_name_from_object_dict = {
+        get_task_name_from_object(task_object): task_object
+        for task_object in task_name_list
+        if isinstance(task_object, Task)
+    }
+    assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
+    return {
+        **task_name_dict,
+        **task_name_from_config_dict,
+        **task_name_from_object_dict,
+    }
\ No newline at end of file
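For context, a hypothetical call to the new `get_task_dict`, mixing a registered task name with an inline config dict; the dataset values below are placeholders, not from this commit:

# Hypothetical sketch of get_task_dict usage; dataset_path/dataset_name values are placeholders.
import lm_eval.tasks                      # importing the task modules runs their @register_task decorators
from lm_eval.api.task import get_task_dict

task_dict = get_task_dict(
    [
        "arc_easy",                       # str -> resolved through TASK_REGISTRY via get_task()
        {"dataset_path": "some_dataset",  # dict -> wrapped in a ConfigurableTask
         "dataset_name": "some_subset"},
    ],
    num_fewshot=0,
)
print(list(task_dict))                    # ["arc_easy", "configurable_some_dataset_some_subset"]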
@@ -58,14 +58,14 @@ def simple_evaluate(
     if isinstance(model, str):
         if model_args is None:
             model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(
+        lm = lm_eval.api.model.get_model(model).create_from_arg_string(
             model_args, {"batch_size": batch_size, "device": device}
         )
     else:
         assert isinstance(model, lm_eval.api.model.LM)
         lm = model
 
-    task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
+    task_dict = lm_eval.api.task.get_task_dict(tasks, num_fewshot=num_fewshot)
 
     if check_integrity:
         run_task_tests(task_list=tasks)
...
+from lm_eval.api.model import LM, MODEL_REGISTRY
 from . import gpt2
 from . import gpt3
 from . import textsynth
 from . import dummy
 
-MODEL_REGISTRY = {
-    "hf-causal": gpt2.HFLM,
-    "openai": gpt3.GPT3LM,
-    "textsynth": textsynth.TextSynthLM,
-    "dummy": dummy.DummyLM,
-}
+# MODEL_REGISTRY = {}
+# MODEL_REGISTRY = {
+#     "hf-causal": gpt2.HFLM,
+#     "openai": gpt3.GPT3LM,
+#     "textsynth": textsynth.TextSynthLM,
+#     "dummy": dummy.DummyLM,
+# }
 
-def get_model(model_name):
-    return MODEL_REGISTRY[model_name]
+# def get_model(model_name):
+#     return MODEL_REGISTRY[model_name]
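A design note, stated as an assumption rather than something this diff documents: the `from . import gpt2` style imports are presumably kept so that each model module's `@register_model` decorator runs at import time and populates `MODEL_REGISTRY` as a side effect, roughly:

# Illustrative only: the registry fills up as a side effect of importing the model modules.
import lm_eval.models                         # triggers `from . import gpt2`, etc.
from lm_eval.api.model import MODEL_REGISTRY

print(sorted(MODEL_REGISTRY))                 # expected to include "hf-causal" once gpt2 is imported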
@@ -6,9 +6,11 @@ from tqdm import tqdm
 import torch.nn.functional as F
 
 from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import LM, register_model
+# from lm_eval.models import register_model
 
 
+@register_model("hf-causal")
 class HFLM(LM):
     def __init__(
         self,
...
This diff is collapsed.
@@ -12,7 +12,7 @@ a co-occurrence method fail to answer correctly) and an Easy Set of 5,197 questi
 Homepage: https://allenai.org/data/arc
 """
-from lm_eval.api.task import MultipleChoiceTask
+from lm_eval.api.task import MultipleChoiceTask, register_task
 from lm_eval.prompts import get_prompt
 from lm_eval import utils
@@ -28,7 +28,7 @@ _CITATION = """
 }
 """
 
+@register_task("arc_easy")
 class ARCEasy(MultipleChoiceTask):
     VERSION = "2.0"
     DATASET_PATH = "ai2_arc"
@@ -80,6 +80,7 @@ class ARCEasy(MultipleChoiceTask):
         return doc["query"]
 
+@register_task("arc_challenge")
 class ARCChallenge(ARCEasy):
     DATASET_PATH = "ai2_arc"
     DATASET_NAME = "ARC-Challenge"
@@ -17,7 +17,7 @@ model's sample/generation function.
 Homepage: https://github.com/openai/grade-school-math
 """
 import re
-from lm_eval.api.task import Task
+from lm_eval.api.task import Task, register_task
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean
@@ -41,6 +41,7 @@ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
 INVALID_ANS = "[invalid]"
 
+@register_task("gsm8k")
 class GradeSchoolMath8K(Task):
     VERSION = 0
     DATASET_PATH = "gsm8k"
...
@@ -12,7 +12,7 @@ in the broader discourse.
 Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
 """
-from lm_eval.api.task import Task
+from lm_eval.api.task import Task, register_task
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import mean, perplexity
@@ -75,6 +75,7 @@ class LambadaBase(Task):
         return {"ppl": False, "acc": True}
 
+@register_task("lambada_standard")
 class LambadaStandard(LambadaBase):
     """The LAMBADA task using the standard original LAMBADA dataset."""
@@ -90,7 +91,7 @@ class LambadaStandard(LambadaBase):
     def has_test_docs(self):
         return True
 
+@register_task("lambada_openai")
 class LambadaOpenAI(LambadaBase):
     """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
     original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.
...
@@ -10,7 +10,7 @@ NOTE: This `Task` is based on WikiText-2.
 Homepage: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
 """
 import re
-from lm_eval.api.task import PerplexityTask
+from lm_eval.api.task import PerplexityTask, register_task
 
 _CITATION = """
@@ -58,7 +58,7 @@ def wikitext_detokenizer(string):
     return string
 
+@register_task("wikitext")
 class WikiText(PerplexityTask):
     VERSION = "2.0"
     DATASET_PATH = "EleutherAI/wikitext_document_level"
...
@@ -5,14 +5,17 @@ import fnmatch
 import yaml
 
 from lm_eval import tasks, evaluator
-from lm_eval.api.task import ConfigurableTask
+# import lm_eval.api.task
+from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY
 
 logging.getLogger("openai").setLevel(logging.WARNING)
 
+ALL_TASKS = sorted(list(TASK_REGISTRY))
+
 
 class MultiChoice:
     def __init__(self, choices):
         self.choices = choices
+        print(f"{ALL_TASKS} is this")
 
     # Simple wildcard support (linux filename patterns)
     def __contains__(self, values):
@@ -31,7 +34,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
     parser.add_argument("--model_args", default="")
-    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
+    parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
     parser.add_argument("--config", default=None)
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
@@ -80,9 +83,9 @@ def main():
             task_names.append(config)
         else:
-            task_names = tasks.ALL_TASKS
+            task_names = ALL_TASKS
     else:
-        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
+        task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
 
     print(f"Selected Tasks: {task_names}")
...
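The `pattern_match` helper used above is not shown in this diff; given the `import fnmatch` at the top of the file and the "Simple wildcard support (linux filename patterns)" comment, its behavior is presumably along these lines (a sketch, not the actual implementation):

# Assumed behavior of pattern_match, sketched here for context; the real helper may differ.
import fnmatch

def pattern_match(patterns, source_list):
    matched = set()
    for pattern in patterns:                          # e.g. ["lambada_*", "arc_easy"]
        for task_name in fnmatch.filter(source_list, pattern):
            matched.add(task_name)
    return sorted(matched)

print(pattern_match(["lambada_*"], ["lambada_standard", "lambada_openai", "wikitext"]))
# -> ['lambada_openai', 'lambada_standard']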