Commit 6a336b15 authored by lintangsutawika

use HFEvaluateAdaptor for hf metrics

parent 20c10dfe
@@ -159,15 +159,29 @@ def acc_mutual_info_fn(items):
    return mean(items)
exact_match = evaluate.load("exact_match")
class HFEvaluateAdaptor:
    def __init__(self, *metric_args, **kwargs):
        metric_object = evaluate.load(*metric_args)
        self.hf_evaluate_fn = partial(metric_object, **kwargs)

    def __call__(self, items):
        refs = list(zip(*items))[0]
        preds = list(zip(*items))[1]
        return self.hf_evaluate_fn(
            references=refs,
            predictions=preds
        )
exact_match = evaluate.load("exact_match")
@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
)
def exact_match_fn(**kwargs):
def hf_evaluate_fn(**kwargs):
    return exact_match.compute(**kwargs)
......
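For readers skimming the diff: the new HFEvaluateAdaptor wraps an HF Evaluate metric so that a list of (reference, prediction) pairs collected by the harness can be scored in one call. The snippet below is an illustrative, self-contained sketch of that pattern, not the committed code; it goes through EvaluationModule.compute, whereas the committed class binds the loaded module object itself via functools.partial.

import evaluate
from functools import partial

class HFMetricSketch:
    # Illustrative stand-in for HFEvaluateAdaptor: load an HF metric once,
    # then score a list of (reference, prediction) pairs in a single call.
    def __init__(self, metric_name, **kwargs):
        module = evaluate.load(metric_name)
        self._compute = partial(module.compute, **kwargs)

    def __call__(self, items):
        refs, preds = zip(*items)
        return self._compute(references=list(refs), predictions=list(preds))

# Example: one matching and one non-matching pair -> {'exact_match': 0.5}
scorer = HFMetricSketch("exact_match", ignore_case=True)
print(scorer([("Paris", "paris"), ("London", "Berlin")]))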
import os
import evaluate
from lm_eval.api.model import LM
from lm_eval.api.metrics import HFEvaluateAdaptor
import logging
eval_logger = logging.getLogger("lm-eval")
@@ -115,7 +115,7 @@ def register_metric(
    return decorate
def get_metric(name, hf_evaluate_metric=False):
def get_metric(name, hf_evaluate_metric=False, **kwargs):
    if not hf_evaluate_metric:
        if name in METRIC_FUNCTION_REGISTRY:
@@ -126,8 +126,8 @@ def get_metric(name, hf_evaluate_metric=False):
            )
    try:
        metric_object = evaluate.load(name)
        return metric_object.compute
        from lm_eval.metrics import HFEvaluateAdaptor
        return HFEvaluateAdaptor(name, **kwargs)
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
......
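To see how the new get_metric path is meant to be driven (an illustrative call, not taken from the diff): when a task asks for a metric hosted in the HF Evaluate library, the registry now returns an HFEvaluateAdaptor instance with any extra metric kwargs already bound, instead of a bare EvaluationModule.compute method.

# Hypothetical invocation mirroring what ConfigurableTask does after this change;
# the keyword arguments come from the task's metric_list entry.
metric_fn = get_metric("exact_match", hf_evaluate_metric=True, ignore_case=True)
score = metric_fn([("Paris", "paris"), ("Rome", "Rome")])  # the adaptor consumes (gold, result) pairs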
@@ -17,7 +17,6 @@ import numpy as np
from typing import Union, List, Any, Tuple, Literal
from collections.abc import Callable
from functools import partial
from lm_eval import utils
from lm_eval.api import samplers
@@ -588,11 +587,11 @@ class ConfigurableTask(Task):
                    metric_name = metric_name.__name__
                else:
                    metric_fn = get_metric(
                        metric_name, hf_evaluate_metric
                        metric_name, hf_evaluate_metric, **kwargs
                    )
                self._metric_fn_kwargs[metric_name] = kwargs
                self._metric_fn_list[metric_name] = partial(metric_fn, **kwargs) if kwargs != {} else metric_fn
                self._metric_fn_list[metric_name] = metric_fn
        self.download(self.config.dataset_kwargs)
        self._training_docs = None
@@ -1106,6 +1105,8 @@ class ConfigurableTask(Task):
                gold = type(result)(gold)
            for metric in self._metric_fn_list.keys():
                result_dict[metric] = (gold, result)
                continue
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
@@ -1141,7 +1142,6 @@ class ConfigurableTask(Task):
                        result_score = self._metric_fn_list[metric](
                            references=[gold],
                            predictions=[result],
                            **self._metric_fn_kwargs[metric],
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                        result_score = self._metric_fn_list[metric]([gold, result])
......
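Taken together, the two task.py hunks change when scoring happens (an illustrative trace, not code from the commit): process_results now just records raw (gold, result) pairs per document, and the HF metric is computed once over the whole collection, which is exactly the batch shape HFEvaluateAdaptor.__call__ unzips into references and predictions.

import evaluate

# Hypothetical per-document pairs as process_results would now store them.
per_doc = [("Paris", "Paris"), ("4", "four"), ("blue", "blue")]

# At aggregation time the pairs are unzipped and scored in one shot.
refs, preds = zip(*per_doc)
result = evaluate.load("exact_match").compute(references=list(refs), predictions=list(preds))
print(result)  # roughly {'exact_match': 0.667}: two of the three pairs match exactly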