Commit 4d8094bb authored by baberabb
Browse files

Improve error logging

parent 6769119f
......@@ -33,7 +33,6 @@ repos:
rev: 22.3.0
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
......
......@@ -101,7 +101,6 @@ def parse_eval_args() -> argparse.Namespace:
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if not args:
# we allow for args to be passed externally, else we parse them ourselves
args = parse_eval_args()
......@@ -132,19 +131,21 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
else:
tasks_list = args.tasks.split(",")
task_names = utils.pattern_match(tasks_list, ALL_TASKS)
task_missing = []
for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
task_names.append(config)
if task_missing != []:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks",
)
raise ValueError(f"Tasks {missing} were not found.")
task_missing = [task for task in tasks_list if task not in task_names]
if task_missing:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks",
)
raise ValueError(
f"Tasks {missing} were not found. Try `lm-eval -h` for list of available tasks."
)
if args.output_path:
path = Path(args.output_path)
......
......@@ -99,7 +99,7 @@ class TaskConfig(dict):
if self.generation_kwargs is not None:
if self.output_type != "greedy_until":
eval_logger.warning(
"passed `generation_kwargs`, but not using `output_type: greedy_until`!"
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: greedy_until`!"
)
assert self.output_type != "greedy_until"
......@@ -759,7 +759,6 @@ class ConfigurableTask(Task):
return super().fewshot_docs()
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances, self.task_docs)
......@@ -967,7 +966,6 @@ class ConfigurableTask(Task):
)
def process_results(self, doc, results):
if callable(self.config.process_results):
return self.config.process_results(doc, results)
......@@ -1104,7 +1102,9 @@ class ConfigurableTask(Task):
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # TODO: this is hacky and I don't want to do it
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = self._metric_fn_list[metric](
[gold_option, result]
)
......@@ -1123,7 +1123,9 @@ class ConfigurableTask(Task):
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
except (
TypeError
): # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
......
......@@ -27,7 +27,9 @@ def register_configurable_task(config: Dict[str, str]) -> int:
register_task(task_name)(SubClass)
if "group" in config:
if type(config["group"]) == str:
if config["group"] == config["task"]:
raise ValueError("task and group name cannot be the same")
elif type(config["group"]) == str:
group_name = [config["group"]]
else:
group_name = config["group"]
......@@ -45,7 +47,6 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list:
task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config(
{
......@@ -137,7 +138,10 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
else:
if type(config["task"]) == list:
register_configurable_group(config, yaml_path)
except ModuleNotFoundError as e:
eval_logger.warning(
f"{yaml_path}: {e}. Config will not be added to registry."
)
except Exception as error:
import traceback
......@@ -187,7 +191,6 @@ def get_task_name_from_object(task_object):
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
config = {**kwargs}
task_name_from_registry_dict = {}
......@@ -199,7 +202,6 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
for task_element in task_name_list:
if isinstance(task_element, str):
if task_element in GROUP_REGISTRY:
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
......@@ -237,7 +239,6 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
}
elif isinstance(task_element, Task):
task_name_from_object_dict = {
**task_name_from_object_dict,
get_task_name_from_object(task_element): task_element,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment