Commit 6a336b15 authored by lintangsutawika

use HFEvaluateAdaptor for hf metrics

parent 20c10dfe
@@ -159,15 +159,29 @@ def acc_mutual_info_fn(items):
     return mean(items)


-exact_match = evaluate.load("exact_match")
+class HFEvaluateAdaptor:
+    def __init__(self, *metric_args, **kwargs):
+        metric_object = evaluate.load(*metric_args)
+        self.hf_evaluate_fn = partial(metric_object, **kwargs)
+
+    def __call__(self, items):
+        refs = list(zip(*items))[0]
+        preds = list(zip(*items))[1]
+
+        return self.hf_evaluate_fn(
+            references=refs,
+            predictions=preds
+        )
+
+
+exact_match = evaluate.load("exact_match")


 @register_metric(
     metric="exact_match",
     higher_is_better=True,
     output_type="generate_until",
 )
-def exact_match_fn(**kwargs):
+def hf_evaluate_fn(**kwargs):
     return exact_match.compute(**kwargs)
...
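For orientation, a minimal sketch of how the new adaptor is meant to be called: it wraps an HF Evaluate metric and scores a list of (gold, prediction) tuples in one corpus-level call. The import path is the one added by this commit; the example data and the ignore_case kwarg are illustrative, not taken from the diff.

# Sketch (assumed usage, not part of the commit).
from lm_eval.api.metrics import HFEvaluateAdaptor

em = HFEvaluateAdaptor("exact_match", ignore_case=True)  # extra kwargs are bound once via partial
items = [("Paris", "paris"), ("London", "Berlin")]       # hypothetical (gold, prediction) pairs
print(em(items))                                         # roughly {'exact_match': 0.5}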
 import os
 import evaluate
 from lm_eval.api.model import LM
+from lm_eval.api.metrics import HFEvaluateAdaptor

 import logging

 eval_logger = logging.getLogger("lm-eval")
@@ -115,7 +115,7 @@ def register_metric(
     return decorate


-def get_metric(name, hf_evaluate_metric=False):
+def get_metric(name, hf_evaluate_metric=False, **kwargs):
     if not hf_evaluate_metric:
         if name in METRIC_FUNCTION_REGISTRY:
@@ -126,8 +126,8 @@ def get_metric(name, hf_evaluate_metric=False):
             )

     try:
-        metric_object = evaluate.load(name)
-        return metric_object.compute
+        from lm_eval.metrics import HFEvaluateAdaptor
+        return HFEvaluateAdaptor(name, **kwargs)
     except Exception:
         eval_logger.error(
             f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
...
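A rough sketch of the two paths through get_metric after this change: a registered lm-eval metric is returned directly, while an HF Evaluate metric is wrapped in an HFEvaluateAdaptor with any extra kwargs bound at construction time. The import path, metric names, and kwargs below are illustrative assumptions, not taken from the diff.

# Sketch (assumed usage).
from lm_eval.api.registry import get_metric

acc_fn = get_metric("acc")                        # registered lm-eval metric -> the registered function
em_fn = get_metric(
    "exact_match",
    hf_evaluate_metric=True,
    ignore_case=True,
)                                                 # HF Evaluate metric -> HFEvaluateAdaptor with kwargs bound
score = em_fn([("gold answer", "gold answer")])   # adaptor consumes (gold, prediction) pairs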
@@ -17,7 +17,6 @@ import numpy as np
 from typing import Union, List, Any, Tuple, Literal
 from collections.abc import Callable
-from functools import partial

 from lm_eval import utils
 from lm_eval.api import samplers
@@ -588,11 +587,11 @@ class ConfigurableTask(Task):
                     metric_name = metric_name.__name__
                 else:
                     metric_fn = get_metric(
-                        metric_name, hf_evaluate_metric
+                        metric_name, hf_evaluate_metric, **kwargs
                     )

                 self._metric_fn_kwargs[metric_name] = kwargs
-                self._metric_fn_list[metric_name] = partial(metric_fn, **kwargs) if kwargs != {} else metric_fn
+                self._metric_fn_list[metric_name] = metric_fn

        self.download(self.config.dataset_kwargs)
        self._training_docs = None
@@ -1106,6 +1105,8 @@ class ConfigurableTask(Task):
                gold = type(result)(gold)

            for metric in self._metric_fn_list.keys():
+                result_dict[metric] = (gold, result)
+                continue
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
@@ -1141,7 +1142,6 @@ class ConfigurableTask(Task):
                        result_score = self._metric_fn_list[metric](
                            references=[gold],
                            predictions=[result],
-                            **self._metric_fn_kwargs[metric],
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                        result_score = self._metric_fn_list[metric]([gold, result])
...
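Taken together, the task-side change means process_results no longer scores each document in place: it records the raw (gold, result) pair for every metric and skips the per-document branches, while metric kwargs are bound earlier (in get_metric / the adaptor) rather than passed at call time. A rough sketch of that flow, with hypothetical data:

# Sketch (assumed, not from the diff): per-document pairs are collected, then scored
# once at aggregation time by the callable returned from get_metric.
per_doc_items = [("yes", "yes"), ("no", "yes"), ("maybe", "maybe")]  # hypothetical (gold, result) pairs
metric_fn = get_metric("exact_match", hf_evaluate_metric=True)       # -> HFEvaluateAdaptor
aggregate_score = metric_fn(per_doc_items)                           # one corpus-level compute() call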