Commit 94346b7e authored by Chris
Browse files

Allow forced import of metrics from the HF Evaluate library

parent 0aa37743
...@@ -117,24 +117,23 @@ def register_metric(**args): ...@@ -117,24 +117,23 @@ def register_metric(**args):
return decorate return decorate
def get_metric(name): def get_metric(name, hf_evaluate_metric=False):
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
else:
eval_logger.warning(
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
)
try: try:
return METRIC_REGISTRY[name] metric_object = evaluate.load(name)
except KeyError: return metric_object.compute
# TODO: change this print to logging? except Exception:
print( eval_logger.error(
f"Could not find registered metric '{name}' in lm-eval, \ f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
searching in HF Evaluate library..."
) )
try:
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
eval_logger.error(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
)
def register_aggregation(name): def register_aggregation(name):
......
...@@ -555,8 +555,9 @@ class ConfigurableTask(Task): ...@@ -555,8 +555,9 @@ class ConfigurableTask(Task):
kwargs = { kwargs = {
key: metric_config[key] key: metric_config[key]
for key in metric_config for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"] if key not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
} }
hf_evaluate_metric = "hf_evaluate" in metric_config and metric_config["hf_evaluate"] == True
if self.config.process_results is not None: if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None self._metric_fn_list[metric_name] = None
...@@ -567,7 +568,7 @@ class ConfigurableTask(Task): ...@@ -567,7 +568,7 @@ class ConfigurableTask(Task):
self._metric_fn_list[metric_name] = metric_fn self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs self._metric_fn_kwargs[metric_name] = kwargs
else: else:
self._metric_fn_list[metric_name] = get_metric(metric_name) self._metric_fn_list[metric_name] = get_metric(metric_name, hf_evaluate_metric)
self._metric_fn_kwargs[metric_name] = kwargs self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config: if "aggregation" in metric_config:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment