Unverified commit 0c53ff49, authored by Lintang Sutawika, committed by GitHub

Merge pull request #580 from EleutherAI/fix-tasks

Fixes some minor issues in task registration and configuration.
parents 6d355b85 ced29179
@@ -31,7 +31,7 @@ def get_model(model_name):
 TASK_REGISTRY = {}
 GROUP_REGISTRY = {}
-ALL_TASKS = []
+ALL_TASKS = set()
 func2task_index = {}
@@ -42,6 +42,7 @@ def register_task(name):
         ), f"task named '{name}' conflicts with existing registered task!"
         TASK_REGISTRY[name] = fn
+        ALL_TASKS.add(name)
         func2task_index[fn.__name__] = name
         return fn
@@ -55,6 +56,7 @@ def register_group(name):
             GROUP_REGISTRY[name].append(func_name)
         else:
             GROUP_REGISTRY[name] = [func_name]
+            ALL_TASKS.add(name)
         return fn
     return decorate
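For intuition, a tiny sketch (with a made-up task name) of why a set is the right container here: both register_task and register_group can now add to ALL_TASKS without producing duplicates, and membership checks are constant-time.

# Hypothetical illustration of the list -> set change; "lambada" is a made-up name.
ALL_TASKS = set()

ALL_TASKS.add("lambada")  # e.g. added by register_task
ALL_TASKS.add("lambada")  # adding the same name again is a harmless no-op
assert ALL_TASKS == {"lambada"}
assert "lambada" in ALL_TASKS  # O(1) membership check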
......
@@ -98,7 +98,9 @@ class TaskConfig(dict):
         self.gold_alias = self.template_aliases + self.doc_to_target
         if self.generation_kwargs or self.output_type == "greedy_until":
-            assert self.output_type == "greedy_until", "passed `generation_kwargs`, but not using a generation request type!"
+            assert (
+                self.output_type == "greedy_until"
+            ), "passed `generation_kwargs`, but not using a generation request type!"
             # ensure that we greedily generate in absence of explicit arguments otherwise
             self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
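The reflowed assert is purely cosmetic; the intent it guards can be sketched standalone roughly as below (the helper name and arguments are assumptions, not the real TaskConfig code): generation arguments only make sense for generation-style requests, and greedy decoding is the fallback when none are given.

def _check_generation_config(output_type, generation_kwargs):
    # hedged sketch of the intent of the block above, not the harness implementation
    if generation_kwargs or output_type == "greedy_until":
        assert (
            output_type == "greedy_until"
        ), "passed `generation_kwargs`, but not using a generation request type!"
        # fall back to greedy decoding when no explicit arguments were given
        return generation_kwargs or {"do_sample": False, "temperature": 0.0}
    return generation_kwargs

_check_generation_config("greedy_until", None)  # -> greedy defaults
# _check_generation_config("loglikelihood", {"num_beams": 4})  # would raise AssertionError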
@@ -460,7 +462,7 @@ class Task(abc.ABC):
         return self._instances

     def dump_config(self):
         """Returns a dictionary representing the task's config.
         :returns: str
             The fewshot context.
@@ -532,7 +534,7 @@ class ConfigurableTask(Task):
                 }
                 try:
                     self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
-                except:
+                except Exception:
                     eval_logger.warning(
                         f"Metric {metric_name} not found, "
                         "Searching from https://huggingface.co/evaluate-metric"
@@ -550,15 +552,24 @@ class ConfigurableTask(Task):
                 if "aggregation" in metric_config:
                     agg_name = metric_config["aggregation"]
-                    self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
+                    if type(agg_name) == str:
+                        self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
+                            agg_name
+                        ]
+                    elif callable(agg_name):
+                        self._aggregation_list[metric_name] = metric_config[
+                            "aggregation"
+                        ]
                 else:
+                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
+                    metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
                     eval_logger.warning(
-                        f"metric {metric_name} is defined, but aggregation is not"
-                        f"using default aggregation for {metric_name}"
+                        f"metric {metric_name} is defined, but aggregation is not. "
+                        f"using default "
+                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
                     )
-                    self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
-                        metric_name
-                    ]
+                    self._aggregation_list[metric_name] = metric_agg

                 if "higher_is_better" in metric_config:
                     self._higher_is_better[metric_name] = metric_config[
@@ -566,8 +577,9 @@ class ConfigurableTask(Task):
                     ]
                 else:
                     eval_logger.warning(
-                        f"metric {metric_name} is defined, but higher_is_better is not"
-                        f"using default higher_is_better for {metric_name}"
+                        f"metric {metric_name} is defined, but higher_is_better is not. "
+                        f"using default "
+                        f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
                     )
                     self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
                         metric_name
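The aggregation branch now accepts either a registry key or a callable. A hedged sketch of that dispatch with a toy registry (the real AGGREGATION_REGISTRY and DEFAULT_AGGREGATION_REGISTRY contents may differ):

AGGREGATION_REGISTRY = {"mean": lambda xs: sum(xs) / len(xs)}  # toy stand-in

def resolve_aggregation(agg_name):
    # string: look the function up by name (the YAML-config path)
    if isinstance(agg_name, str):
        return AGGREGATION_REGISTRY[agg_name]
    # callable: use it directly (a Python-side config can pass a function)
    elif callable(agg_name):
        return agg_name
    raise ValueError(f"unsupported aggregation: {agg_name!r}")

assert resolve_aggregation("mean")([1, 2, 3]) == 2
assert resolve_aggregation(max)([1, 2, 3]) == 3

The inverted registry (INV_AGG_REGISTRY) in the hunk above exists only so the warning can print a readable default name instead of a raw function object.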
@@ -592,9 +604,7 @@ class ConfigurableTask(Task):
                 filter_pipeline = build_filter_ensemble(filter_name, components)
                 self._filters.append(filter_pipeline)
         else:
-            self._filters = [
-                build_filter_ensemble("none", [["take_first", None]])
-            ]
+            self._filters = [build_filter_ensemble("none", [["take_first", None]])]

         if self._config.use_prompt is not None:
             eval_logger.info(f"loading prompt {self._config.use_prompt}")
......
@@ -12,6 +12,7 @@ from lm_eval.api.registry import (
     register_group,
     TASK_REGISTRY,
     GROUP_REGISTRY,
+    ALL_TASKS,
 )
@@ -38,6 +39,9 @@ def include_task_folder(task_dir):
                 )
                 if "task" in config:
+                    # task_name = "{}:{}".format(
+                    #     get_task_name_from_config(config), config["task"]
+                    # )
                     task_name = "{}".format(config["task"])
                     register_task(task_name)(SubClass)
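Since register_task is a decorator factory, calling it directly as register_task(task_name)(SubClass) registers the dynamically built class under the name taken from the YAML's "task" field. A minimal sketch with assumed names:

TASK_REGISTRY = {}
ALL_TASKS = set()

def register_task(name):
    # simplified stand-in for the registry decorator shown earlier in this diff
    def decorate(cls):
        TASK_REGISTRY[name] = cls
        ALL_TASKS.add(name)
        return cls
    return decorate

class SubClass:  # stands in for the dynamically generated task class
    pass

task_name = "my_yaml_task"  # hypothetical value of config["task"]
register_task(task_name)(SubClass)  # same call pattern as in the hunk above
assert TASK_REGISTRY["my_yaml_task"] is SubClass and "my_yaml_task" in ALL_TASKS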
......
@@ -4,11 +4,10 @@ import fnmatch
 import argparse

 from lm_eval import evaluator, utils
-from lm_eval.api.registry import GROUP_REGISTRY, TASK_REGISTRY
+from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

-ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))

 class MultiChoice:
     def __init__(self, choices):
@@ -20,9 +19,8 @@ class MultiChoice:
             if len(fnmatch.filter(self.choices, value)) == 0:
                 eval_logger.warning("{} is not in task list.".format(value))
                 eval_logger.info(f"Available tasks to choose:")
-                # for choice in self.choices:
-                #     eval_logger.info(f"  {choice}")
-                eval_logger.info(ALL_TASKS)
+                for choice in self.choices:
+                    eval_logger.info(f"  - {choice}")
         return True

     def __iter__(self):
@@ -34,7 +32,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
     parser.add_argument("--model_args", default="")
-    parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
+    parser.add_argument("--tasks", default=None, choices=MultiChoice(sorted(ALL_TASKS)))
     parser.add_argument("--config", default=None)
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
......