Commit ca0b8d45 authored by lintangsutawika's avatar lintangsutawika
Browse files

modified how yaml and python functions are added to groups and task registry

parent 275857a1
import os
# Maps a task name to the callable registered under it by ``register_task``.
task_registry = {}
# Maps a group name to the list of task names added to it by ``register_group``.
group_registry = {}
# Maps a task name to the ``__name__`` of the callable registered for that task.
task2func_index = {}
# Maps a callable's ``__name__`` to the task name it was registered under;
# ``register_group`` uses this to resolve the task name of a decorated callable.
func2task_index = {}
def register_task(name):
    """Return a decorator that registers the decorated callable as task ``name``.

    The decorator stores the callable in ``task_registry`` under ``name`` and
    keeps both name indexes in sync: ``func2task_index`` maps the callable's
    ``__name__`` to ``name`` and ``task2func_index`` maps ``name`` back to
    that ``__name__``. The callable is returned unchanged, so the decorator
    can be stacked with others.
    """

    def decorate(target):
        task_registry[name] = target
        target_name = target.__name__
        task2func_index[name] = target_name
        func2task_index[target_name] = name
        return target

    return decorate
def register_group(name):
    """Return a decorator that adds an already-registered task to group ``name``.

    The group stores the *task name*, which is resolved from the decorated
    callable's ``__name__`` through ``func2task_index``. The callable must
    therefore have been passed through ``register_task`` first, i.e.
    ``@register_group`` is written above ``@register_task`` so that it is
    applied afterwards. The callable is returned unchanged.

    Raises:
        KeyError: if the callable's ``__name__`` is not in ``func2task_index``.
    """

    def decorate(target):
        try:
            task_name = func2task_index[target.__name__]
        except KeyError:
            # Same exception type as before, but say what went wrong instead
            # of surfacing only the bare function name.
            raise KeyError(
                f"Cannot add '{target.__name__}' to group '{name}': it is not "
                "registered as a task yet; apply @register_task before "
                "@register_group."
            ) from None

        members = group_registry.setdefault(name, [])
        # Registering the same task in the same group twice must not list it
        # twice, otherwise consumers of the group would see duplicates.
        if task_name not in members:
            members.append(task_name)
        return target

    return decorate
# @register_group('group_a')
# @register_task('a')
# def foo():
# pass
# @register_group('group_a')
# @register_task('b')
# def fii():
# pass
# @register_group('group_b')
# @register_task('c')
# def bar():
# pass
# name = 'a'  # or args.type
# func_to_call = task_registry[name]
# func_to_call()  # actual call is done here
\ No newline at end of file
...@@ -26,10 +26,10 @@ from lm_eval.filters import build_filter_ensemble ...@@ -26,10 +26,10 @@ from lm_eval.filters import build_filter_ensemble
@dataclass @dataclass
class TaskConfig(yaml.YAMLObject): class TaskConfig(dict):
yaml_tag = u'!task'
task: str = None
group: str = None
names: str = None names: str = None
reference: str = None reference: str = None
task_name: str = None # TODO: deprecate this, it'll be set in __post_init__ to be names[0] task_name: str = None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
...@@ -89,7 +89,6 @@ class Task(abc.ABC): ...@@ -89,7 +89,6 @@ class Task(abc.ABC):
VERSION = None VERSION = None
TASK_NAME: str = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script. # or a path to a custom `datasets` loading script.
DATASET_PATH: str = None DATASET_PATH: str = None
...@@ -430,7 +429,8 @@ class ConfigurableTask(Task): ...@@ -430,7 +429,8 @@ class ConfigurableTask(Task):
self._config = TaskConfig(**config) self._config = TaskConfig(**config)
# Overwrite configs # Overwrite configs
else: else:
self._config.__dict__.update(config) if config != None:
self._config.__dict__.update(config)
if self._config is None: if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg") raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
......
import os import os
import re import re
import yaml
from typing import List, Union from typing import List, Union
from .vanilla import * from .arc import *
from lm_eval.utils import get_yaml_config, register_task
from lm_eval.api.task import Task, ConfigurableTask
YAML_REGISTRY = {}
FUNC_REGISTRY = register_task.all
BENCHMARK_REGISTRY = {}
# we want to register all yaml tasks in our .yaml folder. from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
yaml_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + "yaml" from lm_eval.api.register import (
for yaml_file in sorted(os.listdir(yaml_dir)): register_task,
yaml_path = os.path.join(yaml_dir, yaml_file) register_group,
task_registry,
names = re.sub(r"\.", "_", yaml_path.split("/")[-1]) group_registry
YAML_REGISTRY[names] = yaml_path )
TASK_REGISTRY = list(YAML_REGISTRY.keys()) + list(FUNC_REGISTRY.keys()) task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
# Walk the tasks package and register every YAML-defined task, together with
# the groups it declares, by generating a ConfigurableTask subclass per config.
for root, subdirs, file_list in os.walk(task_dir):
    # Only leaf directories (those without sub-directories) are scanned.
    if subdirs or not file_list:
        continue
    for file_name in file_list:
        # Match on the extension; a substring test would also pick up
        # unrelated files that merely contain "yaml" in their name.
        if not file_name.endswith((".yaml", ".yml")):
            continue
        yaml_path = os.path.join(root, file_name)
        try:
            # NOTE(review): full_load accepts more of YAML than safe_load;
            # confirm these files are always trusted, in-repo configs.
            with open(yaml_path, "rb") as yaml_file:
                config = yaml.full_load(yaml_file)

            # A config without a task name cannot be registered; skip it
            # quietly, as before.
            if not isinstance(config, dict) or "task" not in config:
                continue

            SubClass = type(
                config["task"] + "ConfigurableTask",
                (ConfigurableTask,),
                {"CONFIG": TaskConfig(**config)},
            )
            register_task(config["task"])(SubClass)

            # `group` may be a single name or a list of names; iterating a
            # bare string would register one group per character.
            groups = config.get("group") or []
            if isinstance(groups, str):
                groups = [groups]
            for group in groups:
                register_group(group)(SubClass)
        except Exception as err:
            # Discovery stays best-effort: one bad config must not break the
            # import, but it should no longer disappear without a trace.
            print(f"Skipping task config {yaml_path}: {err!r}")

TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(list(TASK_REGISTRY)) ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name, config):
def get_task(task_name, task_config): try:
return TASK_REGISTRY[task_name](config)
if task_name in TASK_REGISTRY: except KeyError:
if task_name in YAML_REGISTRY:
return ConfigurableTask(
config={
**get_yaml_config(YAML_REGISTRY[task_name]),
**task_config
}
)
elif task_name in FUNC_REGISTRY:
return FUNC_REGISTRY[task_name](
config=task_config
)
else:
print("Available tasks:") print("Available tasks:")
pprint(TASK_REGISTRY) pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}") raise KeyError(f"Missing task {task_name}")
...@@ -63,33 +69,76 @@ def get_task_name_from_config(task_config): ...@@ -63,33 +69,76 @@ def get_task_name_from_config(task_config):
# TODO: pass num_fewshot and other cmdline overrides in a better way # TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): def get_task_dict(task_name_list: List[Union[str, dict, Task]], config, **kwargs):
task_name_from_registry_dict = {
task_name: get_task( task_name_from_registry_dict = {}
task_name=task_name, task_name_from_config_dict = {}
task_config={"num_fewshot": num_fewshot if num_fewshot else 0} task_name_from_object_dict = {}
)
for task_name in task_name_list for task_element in task_name_list:
if isinstance(task_name, str) if isinstance(task_element, str):
}
task_name_from_config_dict = { if task_element in GROUP_REGISTRY:
get_task_name_from_config(task_config): ConfigurableTask( for task_name in GROUP_REGISTRY[task_element]:
config=task_config if task_name not in task_name_from_registry_dict:
) task_name_from_registry_dict = {
for task_config in task_name_list **task_name_from_registry_dict,
if isinstance(task_config, dict) task_name: get_task(
} task_name=task_name, config=config
# TODO: Do we still need this? )
}
else:
if task_name not in task_name_from_registry_dict:
task_name_from_registry_dict = {
**task_name_from_registry_dict,
task_name: get_task(
task_name=task_element, config=config
)
}
elif isinstance(task_element, dict):
task_name_from_config_dict = {
**task_name_from_config_dict,
get_task_name_from_config(task_element): ConfigurableTask(
config=config
)
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element
}
# task_name_from_registry_dict = {
# task_name: get_task(
# task_name=task_name,
# task_config=config
# )
# for group_name in task_name_list for task_name in GROUP_REGISTRY[group_name]
# if (isinstance(group_name, str)) and (group_name in GROUP_REGISTRY)
# }
# task_name_from_config_dict = {
# get_task_name_from_config(task_config): ConfigurableTask(
# config=task_config
# )
# for task_config in task_name_list
# if isinstance(task_config, dict)
# }
# # TODO: Do we still need this?
# task_name_from_object_dict = { # task_name_from_object_dict = {
# get_task_name_from_object(task_object): task_object # get_task_name_from_object(task_object): task_object
# for task_object in task_name_list # for task_object in task_name_list
# if isinstance(task_object, Task) # if isinstance(task_object, Task)
# } # }
# assert set(task_name_from_registry_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
assert set(task_name_from_registry_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return { return {
**task_name_from_registry_dict, **task_name_from_registry_dict,
**task_name_from_config_dict, **task_name_from_config_dict,
# **task_name_from_object_dict, **task_name_from_object_dict,
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment