"vscode:/vscode.git/clone" did not exist on "beb524489ce8f36729bf58d23a6ed79e8639621f"
Commit 38244e15 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

add yaml registering decorator

parent 95642aa6
...@@ -27,7 +27,8 @@ from lm_eval.api import samplers ...@@ -27,7 +27,8 @@ from lm_eval.api import samplers
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
task_name: str = None names: str = None
task_name: str = None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
dataset_path: str = None dataset_path: str = None
dataset_name: str = None dataset_name: str = None
training_split: str = None training_split: str = None
...@@ -54,6 +55,8 @@ class TaskConfig(dict): ...@@ -54,6 +55,8 @@ class TaskConfig(dict):
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
use_prompt: str = None use_prompt: str = None
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self): def __post_init__(self):
# allow user-specified aliases so that users can # allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of # force prompt-compatibility for some prompt regardless of
...@@ -61,6 +64,10 @@ class TaskConfig(dict): ...@@ -61,6 +64,10 @@ class TaskConfig(dict):
self.doc_to_text = self.template_aliases + self.doc_to_text self.doc_to_text = self.template_aliases + self.doc_to_text
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
self.task_name = self.names[0]
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
...@@ -268,7 +275,7 @@ class Task(abc.ABC): ...@@ -268,7 +275,7 @@ class Task(abc.ABC):
) )
# TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute # TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 2)) inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 1))
if not isinstance(inst, list): if not isinstance(inst, list):
inst = [inst] inst = [inst]
...@@ -405,12 +412,18 @@ class ConfigurableTask(Task): ...@@ -405,12 +412,18 @@ class ConfigurableTask(Task):
VERSION = "2.0" VERSION = "2.0"
OUTPUT_TYPE = None OUTPUT_TYPE = None
CONFIG = None
def __init__( def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
): ):
# if we are a subclass that has the CONFIG class attr set, ignore whatever is passed.
self._config = TaskConfig(**config) self._config = self.CONFIG
# else, if a config was passed as kwarg: use it
if (self._config is None) and config:
self._config = TaskConfig(**config)
if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
if self._config.output_type is not None: if self._config.output_type is not None:
self.OUTPUT_TYPE = self._config.output_type self.OUTPUT_TYPE = self._config.output_type
...@@ -620,7 +633,6 @@ class ConfigurableTask(Task): ...@@ -620,7 +633,6 @@ class ConfigurableTask(Task):
} }
# TODO: set which normalization metrics should be reported, and calculate them # TODO: set which normalization metrics should be reported, and calculate them
# TODO: add mutual info.
if "exact_match" in self._metric_list.keys(): if "exact_match" in self._metric_list.keys():
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
...@@ -670,7 +682,7 @@ class MultipleChoiceTask(Task): ...@@ -670,7 +682,7 @@ class MultipleChoiceTask(Task):
return " " + doc["choices"][doc["gold"]] return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
# TODO: add mutual info here?
return [Instance( return [Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
...@@ -803,6 +815,38 @@ def register_task(*names): ...@@ -803,6 +815,38 @@ def register_task(*names):
return decorate return decorate
def register_yaml_task(yaml_path):
    """Register a YAML-configured task under every alias listed in its config.

    Same goal as ``register_task()``, but driven by a YAML file instead of a
    decorated class: the file is parsed, a new ``ConfigurableTask`` subclass is
    created whose ``CONFIG`` class attribute is ``TaskConfig(**config)``, and
    that subclass is stored in ``TASK_REGISTRY`` under each entry of the
    config's ``names`` list.

    Args:
        yaml_path: Path of the YAML task config to load. The parsed config
            must provide a ``names`` sequence; ``names[0]`` is used to build
            the generated class name (``<names[0]>ConfigurableTask``).

    Returns:
        The newly created ``ConfigurableTask`` subclass. (The original bound
        it to an unused local and returned ``None``; returning it is
        backward-compatible for callers that ignore the result.)

    Raises:
        AssertionError: If an alias is already present in ``TASK_REGISTRY``.
            Aliases processed before the conflicting one stay registered.
    """
    import yaml

    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so this
    # must only be pointed at trusted config files. Consider yaml.safe_load if
    # the task configs never rely on custom YAML tags.
    with open(yaml_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, yaml.Loader)

    # TODO: strip whitespace from name?
    # TODO: ensure num_fewshot overrides the config vals
    names = config["names"]

    # Create a subclass whose CONFIG class attribute is our yaml config; it is
    # then registered under each of the config's specified aliases.
    task_cls = type(
        names[0] + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    for name in names:
        assert (
            issubclass(task_cls, Task)
        ), f"Task '{name}' ({task_cls.__name__}) must extend Task class"
        assert (
            name not in TASK_REGISTRY
        ), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
        TASK_REGISTRY[name] = task_cls

    # NOTE: the original rebound a *local* ``ALL_TASKS`` inside the loop, which
    # had no effect outside this function (hence its "doesn't seem to import
    # properly" TODO). That dead assignment is dropped; code that needs the
    # sorted task list should derive it from TASK_REGISTRY.
    return task_cls
##### Task registry utils and setup. ##### Task registry utils and setup.
# ALL_TASKS = sorted(list(TASK_REGISTRY)) # ALL_TASKS = sorted(list(TASK_REGISTRY))
......
# from lm_eval.api.task import register_yaml_task import os
from lm_eval.api.task import register_yaml_task
from .vanilla import * from .vanilla import *
# Register every YAML task config found in the "yaml" folder that sits next to
# this file. Entries are visited in sorted order so registration is
# deterministic across platforms.
yaml_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "yaml")
for yaml_filename in sorted(os.listdir(yaml_dir)):
    # The loop variable is deliberately not called ``yaml``: that would shadow
    # the conventional PyYAML module name and leave a stray ``yaml`` string
    # bound at module level after the loop finishes.
    register_yaml_task(os.path.join(yaml_dir, yaml_filename))
names:
- arc_challenge_yaml
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Challenge dataset_name: ARC-Challenge
output_type: multiple_choice output_type: multiple_choice
......
names:
- arc_easy_yaml
dataset_path: ai2_arc dataset_path: ai2_arc
dataset_name: ARC-Easy dataset_name: ARC-Easy
output_type: multiple_choice output_type: multiple_choice
......
names:
- gsm8k_yaml
dataset_path: gsm8k dataset_path: gsm8k
dataset_name: main dataset_name: main
training_split: train training_split: train
......
names:
- lambada_openai_yaml
dataset_path: EleutherAI/lambada_openai dataset_path: EleutherAI/lambada_openai
dataset_name: default dataset_name: default
output_type: loglikelihood output_type: loglikelihood
......
names:
- pile_enron_yaml
dataset_path: EleutherAI/the_pile dataset_path: EleutherAI/the_pile
dataset_name: enron_emails dataset_name: enron_emails
output_type: loglikelihood_rolling output_type: loglikelihood_rolling
......
names:
- sglue_cb_yamltest
dataset_path: super_glue dataset_path: super_glue
dataset_name: cb dataset_name: cb
training_split: train training_split: train
......
...@@ -3,15 +3,16 @@ import json ...@@ -3,15 +3,16 @@ import json
import logging import logging
import fnmatch import fnmatch
import yaml import yaml
import os
from lm_eval import tasks, evaluator from lm_eval import evaluator, tasks
# import lm_eval.api.task
from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY from lm_eval.api.task import ConfigurableTask, TASK_REGISTRY
logging.getLogger("openai").setLevel(logging.WARNING) logging.getLogger("openai").setLevel(logging.WARNING)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
ALL_TASKS = sorted(list(TASK_REGISTRY)) ALL_TASKS = sorted(list(TASK_REGISTRY))
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices):
self.choices = choices self.choices = choices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment