Commit 0a1ced22 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixes some minor issues on tasks. For yaml, the task name should be...

fixes some minor issues on tasks. For yaml, the task name should be <dataset_path>_<dataset_name>:<task>
parent 01cfb2ff
...@@ -98,7 +98,9 @@ class TaskConfig(dict): ...@@ -98,7 +98,9 @@ class TaskConfig(dict):
self.gold_alias = self.template_aliases + self.doc_to_target self.gold_alias = self.template_aliases + self.doc_to_target
if self.generation_kwargs or self.output_type == "greedy_until": if self.generation_kwargs or self.output_type == "greedy_until":
assert self.output_type == "greedy_until", "passed `generation_kwargs`, but not using a generation request type!" assert (
self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!"
# ensure that we greedily generate in absence of explicit arguments otherwise # ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0} self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
...@@ -532,7 +534,7 @@ class ConfigurableTask(Task): ...@@ -532,7 +534,7 @@ class ConfigurableTask(Task):
} }
try: try:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name] self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
except: except Exception:
eval_logger.warning( eval_logger.warning(
f"Metric {metric_name} not found, " f"Metric {metric_name} not found, "
"Searching from https://huggingface.co/evaluate-metric" "Searching from https://huggingface.co/evaluate-metric"
...@@ -550,15 +552,24 @@ class ConfigurableTask(Task): ...@@ -550,15 +552,24 @@ class ConfigurableTask(Task):
if "aggregation" in metric_config: if "aggregation" in metric_config:
agg_name = metric_config["aggregation"] agg_name = metric_config["aggregation"]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name] if type(agg_name) == str:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
agg_name
]
elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else: else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
eval_logger.warning( eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not" f"metric {metric_name} is defined, but aggregation is not. "
f"using default aggregation for {metric_name}" f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
) )
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[ self._aggregation_list[metric_name] = metric_agg
metric_name
]
if "higher_is_better" in metric_config: if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[ self._higher_is_better[metric_name] = metric_config[
...@@ -566,8 +577,9 @@ class ConfigurableTask(Task): ...@@ -566,8 +577,9 @@ class ConfigurableTask(Task):
] ]
else: else:
eval_logger.warning( eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not" f"metric {metric_name} is defined, but higher_is_better is not. "
f"using default higher_is_better for {metric_name}" f"using default "
f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
) )
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[ self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name metric_name
...@@ -592,9 +604,7 @@ class ConfigurableTask(Task): ...@@ -592,9 +604,7 @@ class ConfigurableTask(Task):
filter_pipeline = build_filter_ensemble(filter_name, components) filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline) self._filters.append(filter_pipeline)
else: else:
self._filters = [ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
build_filter_ensemble("none", [["take_first", None]])
]
if self._config.use_prompt is not None: if self._config.use_prompt is not None:
eval_logger.info(f"loading prompt {self._config.use_prompt}") eval_logger.info(f"loading prompt {self._config.use_prompt}")
......
...@@ -38,7 +38,9 @@ def include_task_folder(task_dir): ...@@ -38,7 +38,9 @@ def include_task_folder(task_dir):
) )
if "task" in config: if "task" in config:
task_name = "{}".format(config["task"]) task_name = "{}:{}".format(
get_task_name_from_config(config), config["task"]
)
register_task(task_name)(SubClass) register_task(task_name)(SubClass)
if "group" in config: if "group" in config:
...@@ -56,6 +58,8 @@ def include_task_folder(task_dir): ...@@ -56,6 +58,8 @@ def include_task_folder(task_dir):
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir) include_task_folder(task_dir)
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
def get_task(task_name, config): def get_task(task_name, config):
try: try:
......
...@@ -10,6 +10,7 @@ from lm_eval.logger import eval_logger ...@@ -10,6 +10,7 @@ from lm_eval.logger import eval_logger
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys())) ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices):
self.choices = choices self.choices = choices
...@@ -20,9 +21,8 @@ class MultiChoice: ...@@ -20,9 +21,8 @@ class MultiChoice:
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value)) eval_logger.warning("{} is not in task list.".format(value))
eval_logger.info(f"Available tasks to choose:") eval_logger.info(f"Available tasks to choose:")
# for choice in self.choices: for choice in self.choices:
# eval_logger.info(f" {choice}") eval_logger.info(f" - {choice}")
eval_logger.info(ALL_TASKS)
return True return True
def __iter__(self): def __iter__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment