Commit 8d4d1fa9 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed registered metric

parent 9fbe6eef
......@@ -5,6 +5,7 @@ import numpy as np
import sacrebleu
import sklearn.metrics
import random
import evaluate
from lm_eval.api.registry import register_metric, register_aggregation
......@@ -141,8 +142,8 @@ def acc_mutual_info_fn(items): # This is a passthrough function
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(items): # This is a passthrough function
return items
def exact_match_fn(**kwargs): # This is a passthrough function
return evaluate.load("exact_match").compute(**kwargs)
@register_metric(
......
......@@ -544,6 +544,7 @@ class ConfigurableTask(Task):
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment