__init__.py 5.1 KB
Newer Older
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
from typing import List, Union
&'s avatar
& committed
3

4
5
from .gsm8k import *
from .triviaqa import *
lintangsutawika's avatar
lintangsutawika committed
6

7
from lm_eval import utils
lintangsutawika's avatar
lintangsutawika committed
8
from lm_eval.logger import eval_logger
9
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
10
from lm_eval.api.registry import (
11
12
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
13
14
    TASK_REGISTRY,
    GROUP_REGISTRY,
lintangsutawika's avatar
lintangsutawika committed
15
)
lintangsutawika's avatar
lintangsutawika committed
16

17

18
def get_task_name_from_config(task_config):
    """Derive a task name from a task config mapping.

    Returns ``"<dataset_path>_<dataset_name>"`` when the config contains a
    ``dataset_name`` key, otherwise just ``"<dataset_path>"``. Any other keys
    in the mapping are ignored.

    Raises:
        KeyError: if ``dataset_path`` is missing from ``task_config``.
    """
    # Configs without a ``dataset_name`` key used to raise KeyError here;
    # fall back to the dataset path alone instead.
    if "dataset_name" in task_config:
        return "{dataset_path}_{dataset_name}".format(**task_config)
    return "{dataset_path}".format(**task_config)
# Discover YAML task configs in the leaf directories below this package and
# register a ConfigurableTask subclass for each one (and for its groups).
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
for root, subdirs, file_list in os.walk(task_dir):
    # Only leaf directories (no subdirectories) that contain files are scanned.
    if (subdirs == []) and (len(file_list) > 0):
        for file in file_list:
            # Match the extension, not a substring, so files such as
            # ``yaml_utils.py`` or editor backups ``*.yaml.swp`` are skipped.
            if file.endswith(".yaml"):
                yaml_path = os.path.join(root, file)
                try:
                    config = utils.load_yaml_config(yaml_path)

                    # "task" is required for both the generated class name and
                    # registration; a missing key raises KeyError, which is
                    # reported below like any other load failure.
                    task_name = str(config["task"])
                    SubClass = type(
                        task_name + "ConfigurableTask",
                        (ConfigurableTask,),
                        {"CONFIG": TaskConfig(**config)},
                    )
                    register_task(task_name)(SubClass)

                    if "group" in config:
                        groups = config["group"]
                        # A single group may be given as a plain string; wrap it
                        # so we don't register one group per character.
                        if isinstance(groups, str):
                            groups = [groups]
                        for group in groups:
                            register_group(group)(SubClass)
                except Exception as error:
                    # Best-effort discovery: a broken config is reported and
                    # skipped rather than aborting the package import.
                    eval_logger.warning(
                        "Failed to load config in\n"
                        f"                                 {yaml_path}\n"
                        "                                 Config will not be added to registry\n"
                        f"                                 Error: {error}"
                    )
# Sorted list of every registered task name followed by every group name,
# computed once at import time from the registries as they stand here.
# NOTE(review): a name registered as both a task and a group appears twice.
ALL_TASKS = sorted([*TASK_REGISTRY.keys(), *GROUP_REGISTRY.keys()])
def get_task(task_name, config):
    """Instantiate the task registered under ``task_name``.

    Args:
        task_name: Key into ``TASK_REGISTRY``.
        config: Passed as the ``config`` keyword argument to the task class.

    Returns:
        A new instance of the registered task class.

    Raises:
        KeyError: if ``task_name`` is not registered. The list of available
            task and group names is logged first.
    """
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(ALL_TASKS)
        raise KeyError(f"Missing task {task_name}") from None
    # Construct outside the ``try``: a KeyError raised by the task's own
    # constructor must propagate as-is, not be misreported as an unknown task.
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return a reporting name for a task class or task instance.

    Lookup order:
      1. the ``TASK_REGISTRY`` name whose value is ``task_object`` itself, or
         is exactly ``type(task_object)`` (so registered task instances report
         their registry name);
      2. ``task_object.EVAL_HARNESS_NAME`` if that attribute exists;
      3. the class name (``task_object.__name__`` for a class, otherwise the
         name of the object's type).
    """
    object_type = type(task_object)
    for name, class_ in TASK_REGISTRY.items():
        # Exact identity rather than isinstance(): instances of an unregistered
        # subclass must not be reported under a registered base class's name.
        if class_ is task_object or class_ is object_type:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    if hasattr(task_object, "EVAL_HARNESS_NAME"):
        return task_object.EVAL_HARNESS_NAME
    if isinstance(task_object, type):
        # For a class, type(task_object) is its metaclass (e.g. "type").
        return task_object.__name__
    return object_type.__name__


# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    """Build a mapping from task name to task object.

    Each element of ``task_name_list`` may be:

    * a ``str``: if it is a key of ``GROUP_REGISTRY``, every member of that
      group is instantiated via ``get_task``; otherwise it is instantiated as a
      single task name via ``get_task``;
    * a ``dict`` task config: instantiated as a ``ConfigurableTask`` and named
      with ``get_task_name_from_config``;
    * a ``Task`` instance: used as-is and named with
      ``get_task_name_from_object``.

    Elements of any other type are ignored.

    Args:
        task_name_list: Task names, group names, config dicts and/or tasks.
        **kwargs: Passed as ``config`` to every registry task, and merged on
            top of every dict config (these overrides win).

    Returns:
        dict of task name -> task object, combining registry tasks, then config
        tasks, then task objects; on a name collision the later source wins.

    Raises:
        KeyError: from ``get_task`` when a string is neither a group nor a
            registered task.
        AssertionError: if a registry task and a passed-in task object share a
            name.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    for task_element in task_name_list:
        if isinstance(task_element, str):
            # A group expands to its member task names; any other string is a
            # single task name.
            if task_element in GROUP_REGISTRY:
                member_names = GROUP_REGISTRY[task_element]
            else:
                member_names = [task_element]
            for task_name in member_names:
                # Instantiate each task at most once, even when it is requested
                # both directly and through a group.
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict[task_name] = get_task(
                        task_name=task_name, config=config
                    )

        elif isinstance(task_element, dict):
            # Merge into a new dict so the caller's config is not mutated;
            # keyword overrides take precedence over the dict's own values.
            task_config = {**task_element, **config}
            task_name_from_config_dict[
                get_task_name_from_config(task_config)
            ] = ConfigurableTask(config=task_config)

        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }