Commit cf79b1c6 authored by lintangsutawika's avatar lintangsutawika
Browse files

configurable can be overidden with newer configs

parent cac85281
......@@ -11,16 +11,15 @@ import functools
import datasets
import numpy as np
from typing import List, Union
from lm_eval import utils
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval.api.metrics import (
METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY,
get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
)
from lm_eval.tasks import TASK_REGISTRY
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
......@@ -84,6 +83,8 @@ class Task(abc.ABC):
"""
VERSION = None
TASK_NAME: str = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: str = None
......@@ -418,13 +419,15 @@ class ConfigurableTask(Task):
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
):
# if we are a subclass that has the CONFIG class attr set, ignore whatever is passed.
# Get pre-configured attributes
self._config = self.CONFIG
# else, if a config was passed as kwarg: use it
if (self._config is None) and config:
# Use new configurations if there was no preconfiguration
if self._config is None:
self._config = TaskConfig(**config)
elif config["num_fewshot"] != 0:
self._config.num_fewshot = config["num_fewshot"]
# Overwrite configs
else:
self._config.__dict__.update(config)
if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
......@@ -791,118 +794,3 @@ class PerplexityTask(Task, abc.ABC):
def count_words(cls, doc):
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
# TODO: confirm we want this to go in this file
TASK_REGISTRY = {}
ALL_TASKS = []
def register_task(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import right.
return cls
return decorate
def register_yaml_task(yaml_path):
# same goal as register_task() but used to register yamls
import yaml
with open(yaml_path, "r") as f:
config = yaml.load(f, yaml.Loader)
from functools import partial
# TODO: strip whitespace from name?
# TODO: ensure num_fewshot overrides the config vals
def decorate(names, cls):
for name in names:
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import properly.
return cls
# we create a subclass that has subclass attr CONFIG = our yaml config, and decorate with the config's specified aliases
names = config['names']
yaml_task = decorate(
names,
type(config['names'][0] + 'ConfigurableTask', (ConfigurableTask,), {'CONFIG': TaskConfig(**config)})
)
##### Task registry utils and setup.
# ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
try:
return TASK_REGISTRY[task_name]
except KeyError:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
task_name_dict = {
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0})
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_config_dict = {
get_task_name_from_config(task_config): ConfigurableTask(
config=task_config
)
for task_config in task_name_list
if isinstance(task_config, dict)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list
if isinstance(task_object, Task)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {
**task_name_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment