Unverified commit 6be66284, authored by Lintang Sutawika and committed by GitHub
Browse files

Merge branch 'big-refactor' into more-docs

parents 400c0199 0c53ff49
......@@ -36,7 +36,7 @@ def get_model(model_name):
TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = []
ALL_TASKS = set()
func2task_index = {}
......@@ -47,6 +47,7 @@ def register_task(name):
), f"task named '{name}' conflicts with existing registered task!"
TASK_REGISTRY[name] = fn
ALL_TASKS.add(name)
func2task_index[fn.__name__] = name
return fn
......@@ -60,6 +61,7 @@ def register_group(name):
GROUP_REGISTRY[name].append(func_name)
else:
GROUP_REGISTRY[name] = [func_name]
ALL_TASKS.add(name)
return fn
return decorate
......
......@@ -435,7 +435,7 @@ class Task(abc.ABC):
if num_fewshot == 0:
labeled_examples = ""
else:
labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
labeled_examples = self.sampler.get_context(doc, num_fewshot)
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
# if self.has_training_docs():
......@@ -566,15 +566,24 @@ class ConfigurableTask(Task):
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
if type(agg_name) == str:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
agg_name
]
elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
f"metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
......@@ -582,8 +591,9 @@ class ConfigurableTask(Task):
]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
f"metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
)
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
......
......@@ -13,6 +13,7 @@ from lm_eval.api.registry import (
register_group,
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
......@@ -39,6 +40,9 @@ def include_task_folder(task_dir):
)
if "task" in config:
# task_name = "{}:{}".format(
# get_task_name_from_config(config), config["task"]
# )
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
......
......@@ -4,11 +4,10 @@ import fnmatch
import argparse
from lm_eval import evaluator, utils
from lm_eval.api.registry import GROUP_REGISTRY, TASK_REGISTRY
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
class MultiChoice:
......@@ -21,9 +20,8 @@ class MultiChoice:
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value))
eval_logger.info(f"Available tasks to choose:")
# for choice in self.choices:
# eval_logger.info(f" {choice}")
eval_logger.info(ALL_TASKS)
for choice in self.choices:
eval_logger.info(f" - {choice}")
return True
def __iter__(self):
......@@ -35,7 +33,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default=None, choices=MultiChoice(ALL_TASKS))
parser.add_argument("--tasks", default=None, choices=MultiChoice(sorted(ALL_TASKS)))
parser.add_argument("--config", default=None)
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment