"vscode:/vscode.git/clone" did not exist on "1561edcd983905b8d5f662a55d00994a237c004b"
Commit ca0b8d45 authored by lintangsutawika's avatar lintangsutawika
Browse files

Modified how YAML configs and Python functions are added to the group and task registries

parent 275857a1
import os
# Task name -> registered object; written by ``register_task``.
task_registry = {}
# Group name -> list of member task names; written by ``register_group``.
group_registry = {}
# Task name -> ``__name__`` of the object registered under that task name.
task2func_index = {}
# ``__name__`` of a registered object -> the task name it was registered under.
func2task_index = {}
def register_task(name):
    """Return a decorator that registers its target as the task ``name``.

    The decorator stores the target in ``task_registry`` under ``name`` and
    records the two-way mapping between ``name`` and the target's
    ``__name__`` in ``task2func_index`` / ``func2task_index``. The target is
    returned unchanged, so the decorator can be stacked with others.
    """

    def decorator(target):
        task_registry[name] = target
        target_name = target.__name__
        func2task_index[target_name] = name
        task2func_index[name] = target_name
        return target

    return decorator
def register_group(name):
    """Return a decorator that adds an already-registered task to group ``name``.

    The task name is looked up in ``func2task_index`` by the decorated
    object's ``__name__``, so the object must have gone through
    ``register_task`` first. With stacked decorators this means
    ``@register_group`` is written *above* ``@register_task`` (decorators are
    applied bottom-up).

    Raises:
        KeyError: if the decorated object has not been registered as a task.
    """

    def wrapper(func):
        try:
            task_name = func2task_index[func.__name__]
        except KeyError:
            # Same exception type as a plain lookup, but with a message that
            # says what actually went wrong.
            raise KeyError(
                f"Cannot add {func.__name__!r} to group {name!r}: it has not "
                "been registered with register_task yet"
            ) from None
        members = group_registry.setdefault(name, [])
        # Registering the same task twice must not list it twice in the group.
        if task_name not in members:
            members.append(task_name)
        return func

    return wrapper
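# Usage example. Decorators are applied bottom-up, so ``register_task`` runs
# before ``register_group`` and the task name is already known when the group
# is filled:
#
# @register_group('group_a')
# @register_task('a')
# def foo():
#     pass
#
# @register_group('group_a')
# @register_task('b')
# def fii():
#     pass
#
# @register_group('group_b')
# @register_task('c')
# def bar():
#     pass
#
# name = 'a'  # or args.type
# func_to_call = task_registry[name]
# func_to_call()  # actual call is done here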
\ No newline at end of file
......@@ -26,10 +26,10 @@ from lm_eval.filters import build_filter_ensemble
@dataclass
class TaskConfig(yaml.YAMLObject):
yaml_tag = u'!task'
class TaskConfig(dict):
task: str = None
group: str = None
names: str = None
reference: str = None
task_name: str = None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
......@@ -89,7 +89,6 @@ class Task(abc.ABC):
VERSION = None
TASK_NAME: str = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: str = None
......@@ -430,7 +429,8 @@ class ConfigurableTask(Task):
self._config = TaskConfig(**config)
# Overwrite configs
else:
self._config.__dict__.update(config)
if config != None:
self._config.__dict__.update(config)
if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
......
import os
import re
import yaml
from typing import List, Union
from .vanilla import *
from lm_eval.utils import get_yaml_config, register_task
from lm_eval.api.task import Task, ConfigurableTask
YAML_REGISTRY = {}
FUNC_REGISTRY = register_task.all
BENCHMARK_REGISTRY = {}
from .arc import *
# we want to register all yaml tasks in our .yaml folder.
yaml_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + "yaml"
for yaml_file in sorted(os.listdir(yaml_dir)):
yaml_path = os.path.join(yaml_dir, yaml_file)
names = re.sub(r"\.", "_", yaml_path.split("/")[-1])
YAML_REGISTRY[names] = yaml_path
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.register import (
register_task,
register_group,
task_registry,
group_registry
)
TASK_REGISTRY = list(YAML_REGISTRY.keys()) + list(FUNC_REGISTRY.keys())
# Walk the package directory and register every YAML task config found in a
# leaf directory as a dynamically created ``ConfigurableTask`` subclass.
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
for root, subdirs, file_list in os.walk(task_dir):
    # Only leaf directories (those without sub-directories) are scanned.
    if subdirs:
        continue
    for file in file_list:
        # Match on the extension; a substring test would also pick up
        # unrelated files that merely have "yaml" somewhere in their name.
        if not file.endswith((".yaml", ".yml")):
            continue
        yaml_path = os.path.join(root, file)
        try:
            # NOTE(review): ``yaml.full_load`` is kept from the original; it
            # is only appropriate while these files ship with the package and
            # are trusted. Consider ``yaml.safe_load`` if that ever changes.
            with open(yaml_path, "rb") as handle:
                config = yaml.full_load(handle)
            # A YAML file without a ``task`` key is not a task config.
            if not isinstance(config, dict) or "task" not in config:
                continue
            SubClass = type(
                config["task"] + "ConfigurableTask",
                (ConfigurableTask,),
                {"CONFIG": TaskConfig(**config)},
            )
            register_task(config["task"])(SubClass)
            groups = config.get("group") or []
            # ``group`` may be a single name; iterating a bare string would
            # register one group per character.
            if isinstance(groups, str):
                groups = [groups]
            for group in groups:
                register_group(group)(SubClass)
        except Exception as err:
            # Registration stays best-effort (one broken config must not
            # break importing the package), but the failure is reported
            # instead of being silently swallowed.
            print(f"Skipping task config {yaml_path}: {err!r}")

# Public aliases for the registries filled in above.
TASK_REGISTRY = task_registry
GROUP_REGISTRY = group_registry
ALL_TASKS = sorted(TASK_REGISTRY)
def get_task(task_name, task_config):
if task_name in TASK_REGISTRY:
if task_name in YAML_REGISTRY:
return ConfigurableTask(
config={
**get_yaml_config(YAML_REGISTRY[task_name]),
**task_config
}
)
elif task_name in FUNC_REGISTRY:
return FUNC_REGISTRY[task_name](
config=task_config
)
else:
def get_task(task_name, config):
    """Instantiate the task registered under ``task_name``.

    Args:
        task_name: key to look up in ``TASK_REGISTRY``.
        config: passed as the single positional argument to the registered
            task callable.

    Returns:
        Whatever the registered callable returns for ``config``.

    Raises:
        KeyError: if ``task_name`` is not in ``TASK_REGISTRY``; the available
            tasks are printed first.
    """
    # Only the lookup is guarded: a KeyError raised while *constructing* the
    # task must propagate as-is rather than be reported as a missing task.
    try:
        task_factory = TASK_REGISTRY[task_name]
    except KeyError:
        print("Available tasks:")
        pprint(TASK_REGISTRY)
        raise KeyError(f"Missing task {task_name}")
    return task_factory(config)
......@@ -63,33 +69,76 @@ def get_task_name_from_config(task_config):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None):
task_name_from_registry_dict = {
task_name: get_task(
task_name=task_name,
task_config={"num_fewshot": num_fewshot if num_fewshot else 0}
)
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_config_dict = {
get_task_name_from_config(task_config): ConfigurableTask(
config=task_config
)
for task_config in task_name_list
if isinstance(task_config, dict)
}
# TODO: Do we still need this?
def get_task_dict(task_name_list: List[Union[str, dict, Task]], config, **kwargs):
    """Build a ``{task name: task object}`` mapping from a mixed task list.

    Each element of ``task_name_list`` is handled by type:

    * ``str`` - a group name in ``GROUP_REGISTRY`` is expanded to its member
      tasks; any other string is treated as a single task name. Each task is
      created once through ``get_task``.
    * ``dict`` - an inline task config; a ``ConfigurableTask`` is built from
      it, with entries from ``config`` taking precedence.
    * ``Task`` - an existing task object, used as-is.

    Args:
        task_name_list: task names, group names, inline configs and/or task
            objects.
        config: shared config handed to every task created here. Assumed to
            be a mapping or ``None`` - TODO confirm against callers.
        **kwargs: accepted for call compatibility; not used.

    Returns:
        One dict merging registry, inline-config and object tasks, in that
        order of precedence (later entries win on a name clash).

    Raises:
        AssertionError: if a registry task and a task object share a name.
    """
    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                task_names = GROUP_REGISTRY[task_element]
            else:
                # Not a group: the element itself is the task name.
                task_names = [task_element]
            for task_name in task_names:
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict[task_name] = get_task(
                        task_name=task_name, config=config
                    )
        elif isinstance(task_element, dict):
            # Build the task from its own inline config; shared ``config``
            # entries override it. Merged into a new dict so the caller's
            # element is not mutated.
            merged_config = {**task_element, **(config or {})}
            task_name_from_config_dict[
                get_task_name_from_config(task_element)
            ] = ConfigurableTask(config=merged_config)
        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    ), "a task object uses the same name as a task loaded from the registry"

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment