Commit 4bc837d4 authored by lintangsutawika's avatar lintangsutawika
Browse files

modified registry process

parent 35730ace
import os
import re
from lm_eval.api.task import register_yaml_task
from typing import List, Union
from .vanilla import *
from lm_eval.utils import get_yaml_config, register_task
from lm_eval.api.task import Task, ConfigurableTask
YAML_REGISTRY = {}
FUNC_REGISTRY = register_task.all
BENCHMARK_REGISTRY = {}
# we want to register all yaml tasks in our .yaml folder.
yaml_dir = os.path.dirname(os.path.abspath(__file__)) + "/" + "yaml"
for yaml_file in sorted(os.listdir(yaml_dir)):
yaml_path = os.path.join(yaml_dir, yaml_file)
names = re.sub(r"\.", "_", yaml_path.split("/")[-1])
YAML_REGISTRY[names] = yaml_path
TASK_REGISTRY = list(YAML_REGISTRY.keys()) + list(FUNC_REGISTRY.keys())
ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name, task_config):
if task_name in TASK_REGISTRY:
if task_name in YAML_REGISTRY:
return ConfigurableTask(
config={
**get_yaml_config(YAML_REGISTRY[task_name]),
**task_config
}
)
elif task_name in FUNC_REGISTRY:
return FUNC_REGISTRY[task_name](
config=task_config
)
else:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None):
task_name_from_registry_dict = {
task_name: get_task(
task_name=task_name,
task_config={"num_fewshot": num_fewshot if num_fewshot else 0}
)
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_config_dict = {
get_task_name_from_config(task_config): ConfigurableTask(
config=task_config
)
for task_config in task_name_list
if isinstance(task_config, dict)
}
# TODO: Do we still need this?
# task_name_from_object_dict = {
# get_task_name_from_object(task_object): task_object
# for task_object in task_name_list
# if isinstance(task_object, Task)
# }
# assert set(task_name_from_registry_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {
**task_name_from_registry_dict,
**task_name_from_config_dict,
# **task_name_from_object_dict,
}
for yaml in sorted(os.listdir(yaml_dir)):
yaml = os.path.join(yaml_dir, yaml)
register_yaml_task(yaml)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment