Commit 1813bf04 authored by lintangsutawika

skips metrics prep if process_results is not None

parent 0eb94c8b
 import os
 import evaluate
 from lm_eval.api.model import LM
+from lm_eval.logger import eval_logger

 MODEL_REGISTRY = {}
@@ -131,7 +132,7 @@ searching in HF Evaluate library..."
         metric_object = evaluate.load(name)
         return metric_object.compute
     except Exception:
-        raise Warning(
+        eval_logger.error(
             "{} not found in the evaluate library!".format(name),
             "Please check https://huggingface.co/evaluate-metric",
         )
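For context on the path this hunk guards: when a metric name is not in the local registry, `get_metric` falls back to resolving it from the Hugging Face Evaluate hub. A minimal sketch of that fallback, assuming the `evaluate` package is installed ("exact_match" is an illustrative metric name, not one taken from this diff):

```python
import evaluate

# Resolve a metric from the HF Evaluate hub and keep its compute function,
# mirroring the `evaluate.load(name)` / `metric_object.compute` lines above.
metric_object = evaluate.load("exact_match")
compute_fn = metric_object.compute

# Evaluate metrics are called with predictions/references keyword arguments.
result = compute_fn(predictions=["hello"], references=["hello"])
print(result)  # {'exact_match': 1.0}
```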
@@ -154,7 +155,7 @@ def get_aggregation(name):
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
-        raise Warning(
+        eval_logger.warning(
             "{} not a registered aggregation metric!".format(name),
         )
@@ -163,7 +164,9 @@ def get_default_aggregation(metric_name):
     try:
         return DEFAULT_AGGREGATION_REGISTRY[metric_name]
     except KeyError:
-        raise Warning(f"No default aggregation metric for metric '{metric_name}'!")
+        eval_logger.warning(
+            f"No default aggregation metric for metric '{metric_name}'!"
+        )


 def is_higher_better(metric_name):
@@ -171,3 +174,6 @@ def is_higher_better(metric_name):
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
-        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
+        eval_logger.warning(
+            f"higher_is_better not specified for metric '{metric_name}'!"
+        )
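The pattern behind all four hunks above: `Warning` is a built-in exception class, so `raise Warning(...)` aborts the caller exactly like any other raised exception, making a missing registry entry fatal. Logging through `eval_logger` makes the lookup non-fatal; the function then falls through and returns `None`. A self-contained sketch of the before/after behavior, using an illustrative registry and a stand-in logger:

```python
import logging

eval_logger = logging.getLogger("lm-eval")  # stand-in for lm_eval.logger
HIGHER_IS_BETTER_REGISTRY = {"acc": True}   # illustrative contents

def is_higher_better_old(metric_name):
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        # Warning subclasses Exception, so this raises and halts the caller.
        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")

def is_higher_better_new(metric_name):
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        # Log and continue; the function implicitly returns None.
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )

print(is_higher_better_new("acc"))   # True
print(is_higher_better_new("bleu"))  # logs a warning, prints None
```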
@@ -554,8 +554,13 @@ class ConfigurableTask(Task):
                     for key in metric_config
                     if key not in ["metric", "aggregation", "higher_is_better"]
                 }
-                self._metric_fn_list[metric_name] = get_metric(metric_name)
-                self._metric_fn_kwargs[metric_name] = kwargs
+
+                if self._config.process_results is None:
+                    self._metric_fn_list[metric_name] = get_metric(metric_name)
+                    self._metric_fn_kwargs[metric_name] = kwargs
+                else:
+                    self._metric_fn_list[metric_name] = None
+                    self._metric_fn_kwargs[metric_name] = {}

                 if "aggregation" in metric_config:
                     agg_name = metric_config["aggregation"]
...
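The intent of this hunk, per the commit message: a task that supplies its own `process_results` hook computes metric values itself, so resolving metric callables via `get_metric` (which may reach out to the HF Evaluate hub) is unnecessary and is now skipped. A simplified sketch of the gate, with a hypothetical `TaskConfig` and a free-standing `prepare_metric` reduced to just the fields this hunk touches:

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class TaskConfig:
    # When set, the task post-processes model outputs into metric values
    # itself, so no per-metric function needs to be resolved up front.
    process_results: Optional[Callable] = None

def prepare_metric(config: TaskConfig, metric_name: str, kwargs: dict,
                   get_metric: Callable) -> tuple:
    """Mirrors the gating logic in the hunk above."""
    if config.process_results is None:
        # Standard path: resolve the metric callable and keep its kwargs.
        return get_metric(metric_name), kwargs
    else:
        # A custom process_results supplies metric values; skip metric prep.
        return None, {}

# Usage: with a custom hook installed, get_metric is never called.
def never_called(name):
    raise AssertionError("get_metric should not run when process_results is set")

fn, kw = prepare_metric(
    TaskConfig(process_results=lambda doc, results: {"acc": 1.0}),
    "acc", {}, get_metric=never_called,
)
assert fn is None and kw == {}
```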