"include/ck/utility/amd_inline_asm.hpp" did not exist on "6fe3627a9eb35f1237266f1b6cc8fd3456aed67d"
Unverified Commit 50423907 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #581 from EleutherAI/all_tasks_mutable

[Refactor] ALL_TASKS now maintained (not static)
parents 0a1ced22 3c284567
...@@ -31,7 +31,7 @@ def get_model(model_name): ...@@ -31,7 +31,7 @@ def get_model(model_name):
TASK_REGISTRY = {} TASK_REGISTRY = {}
GROUP_REGISTRY = {} GROUP_REGISTRY = {}
ALL_TASKS = [] ALL_TASKS = set()
func2task_index = {} func2task_index = {}
...@@ -42,6 +42,7 @@ def register_task(name): ...@@ -42,6 +42,7 @@ def register_task(name):
), f"task named '{name}' conflicts with existing registered task!" ), f"task named '{name}' conflicts with existing registered task!"
TASK_REGISTRY[name] = fn TASK_REGISTRY[name] = fn
ALL_TASKS.add(name)
func2task_index[fn.__name__] = name func2task_index[fn.__name__] = name
return fn return fn
...@@ -55,6 +56,7 @@ def register_group(name): ...@@ -55,6 +56,7 @@ def register_group(name):
GROUP_REGISTRY[name].append(func_name) GROUP_REGISTRY[name].append(func_name)
else: else:
GROUP_REGISTRY[name] = [func_name] GROUP_REGISTRY[name] = [func_name]
ALL_TASKS.add(name)
return fn return fn
return decorate return decorate
......
...@@ -12,6 +12,7 @@ from lm_eval.api.registry import ( ...@@ -12,6 +12,7 @@ from lm_eval.api.registry import (
register_group, register_group,
TASK_REGISTRY, TASK_REGISTRY,
GROUP_REGISTRY, GROUP_REGISTRY,
ALL_TASKS,
) )
...@@ -58,8 +59,6 @@ def include_task_folder(task_dir): ...@@ -58,8 +59,6 @@ def include_task_folder(task_dir):
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir) include_task_folder(task_dir)
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
def get_task(task_name, config): def get_task(task_name, config):
try: try:
......
...@@ -4,12 +4,10 @@ import fnmatch ...@@ -4,12 +4,10 @@ import fnmatch
import argparse import argparse
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.api.registry import GROUP_REGISTRY, TASK_REGISTRY from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices):
...@@ -34,7 +32,7 @@ def parse_args(): ...@@ -34,7 +32,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True) parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="") parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS)) parser.add_argument("--tasks", default=None, choices=MultiChoice(sorted(ALL_TASKS)))
parser.add_argument("--config", default=None) parser.add_argument("--config", default=None)
parser.add_argument("--provide_description", action="store_true") parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0) parser.add_argument("--num_fewshot", type=int, default=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment