Commit 0a1ced22 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixes some minor issues on tasks. For yaml, the task name should be...

fixes some minor issues on tasks. For yaml, the task name should be <dataset_path>_<dataset_name>:<task>
parent 01cfb2ff
......@@ -98,7 +98,9 @@ class TaskConfig(dict):
self.gold_alias = self.template_aliases + self.doc_to_target
if self.generation_kwargs or self.output_type == "greedy_until":
assert self.output_type == "greedy_until", "passed `generation_kwargs`, but not using a generation request type!"
assert (
self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!"
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
......@@ -460,7 +462,7 @@ class Task(abc.ABC):
return self._instances
def dump_config(self):
"""Returns a dictionary representing the task's config.
"""Returns a dictionary representing the task's config.
:returns: str
The fewshot context.
......@@ -532,7 +534,7 @@ class ConfigurableTask(Task):
}
try:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
except:
except Exception:
eval_logger.warning(
f"Metric {metric_name} not found, "
"Searching from https://huggingface.co/evaluate-metric"
......@@ -550,15 +552,24 @@ class ConfigurableTask(Task):
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[agg_name]
if type(agg_name) == str:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
agg_name
]
elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = DEFAULT_AGGREGATION_REGISTRY[metric_name]
eval_logger.warning(
f"metric {metric_name} is defined, but aggregation is not"
f"using default aggregation for {metric_name}"
f"metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
metric_name
]
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
......@@ -566,8 +577,9 @@ class ConfigurableTask(Task):
]
else:
eval_logger.warning(
f"metric {metric_name} is defined, but higher_is_better is not"
f"using default higher_is_better for {metric_name}"
f"metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={HIGHER_IS_BETTER_REGISTRY[metric_name]}"
)
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
......@@ -592,9 +604,7 @@ class ConfigurableTask(Task):
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
else:
self._filters = [
build_filter_ensemble("none", [["take_first", None]])
]
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self._config.use_prompt is not None:
eval_logger.info(f"loading prompt {self._config.use_prompt}")
......
......@@ -38,7 +38,9 @@ def include_task_folder(task_dir):
)
if "task" in config:
task_name = "{}".format(config["task"])
task_name = "{}:{}".format(
get_task_name_from_config(config), config["task"]
)
register_task(task_name)(SubClass)
if "group" in config:
......@@ -56,6 +58,8 @@ def include_task_folder(task_dir):
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
def get_task(task_name, config):
try:
......
......@@ -10,6 +10,7 @@ from lm_eval.logger import eval_logger
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ALL_TASKS = sorted(list(TASK_REGISTRY.keys()) + list(GROUP_REGISTRY.keys()))
class MultiChoice:
def __init__(self, choices):
self.choices = choices
......@@ -20,9 +21,8 @@ class MultiChoice:
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value))
eval_logger.info(f"Available tasks to choose:")
# for choice in self.choices:
# eval_logger.info(f" {choice}")
eval_logger.info(ALL_TASKS)
for choice in self.choices:
eval_logger.info(f" - {choice}")
return True
def __iter__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment