Commit 20c10dfe authored by lintangsutawika's avatar lintangsutawika
Browse files

kwargs are added to metric_fn through partial at the beginning

parent e5b245cc
......@@ -17,6 +17,7 @@ import numpy as np
from typing import Union, List, Any, Tuple, Literal
from collections.abc import Callable
from functools import partial
from lm_eval import utils
from lm_eval.api import samplers
......@@ -553,7 +554,6 @@ class ConfigurableTask(Task):
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
......@@ -580,19 +580,19 @@ class ConfigurableTask(Task):
and metric_config["hf_evaluate"] is True
)
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
# if self.config.process_results is not None:
# self._metric_fn_list[metric_name] = None
# self._metric_fn_kwargs[metric_name] = {}
if callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(
metric_fn = get_metric(
metric_name, hf_evaluate_metric
)
self._metric_fn_kwargs[metric_name] = kwargs
self._metric_fn_kwargs[metric_name] = kwargs
self._metric_fn_list[metric_name] = partial(metric_fn, **kwargs) if kwargs != {} else metric_fn
self.download(self.config.dataset_kwargs)
self._training_docs = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment