"composable_kernel/include/utility/static_buffer.hpp" did not exist on "01055d95d9df0fd79eeaa02336593d6432b09a7f"
Commit 8d4d1fa9 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed registered metric

parent 9fbe6eef
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import sacrebleu import sacrebleu
import sklearn.metrics import sklearn.metrics
import random import random
import evaluate
from lm_eval.api.registry import register_metric, register_aggregation from lm_eval.api.registry import register_metric, register_aggregation
...@@ -141,8 +142,8 @@ def acc_mutual_info_fn(items): # This is a passthrough function ...@@ -141,8 +142,8 @@ def acc_mutual_info_fn(items): # This is a passthrough function
output_type="generate_until", output_type="generate_until",
aggregation="mean", aggregation="mean",
) )
def exact_match_fn(items): # This is a passthrough function def exact_match_fn(**kwargs): # This is a passthrough function
return items return evaluate.load("exact_match").compute(**kwargs)
@register_metric( @register_metric(
......
...@@ -544,6 +544,7 @@ class ConfigurableTask(Task): ...@@ -544,6 +544,7 @@ class ConfigurableTask(Task):
for metric_name in _metric_list: for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name) self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation( self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name metric_name
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment