"...composable_kernel_rocm.git" did not exist on "9fb64dae572f774d88329fe9aca6d809928ba9db"
Commit 20c10dfe authored by lintangsutawika's avatar lintangsutawika
Browse files

kwargs are added to metric_fn through partial at the beginning

parent e5b245cc
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
from typing import Union, List, Any, Tuple, Literal from typing import Union, List, Any, Tuple, Literal
from collections.abc import Callable from collections.abc import Callable
from functools import partial
from lm_eval import utils from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
...@@ -553,7 +554,6 @@ class ConfigurableTask(Task): ...@@ -553,7 +554,6 @@ class ConfigurableTask(Task):
self._metric_fn_list = {} self._metric_fn_list = {}
self._metric_fn_kwargs = {} self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {} self._higher_is_better = {}
if self.config.metric_list is None: if self.config.metric_list is None:
...@@ -580,19 +580,19 @@ class ConfigurableTask(Task): ...@@ -580,19 +580,19 @@ class ConfigurableTask(Task):
and metric_config["hf_evaluate"] is True and metric_config["hf_evaluate"] is True
) )
if self.config.process_results is not None: # if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None # self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {} # self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name): if callable(metric_name):
metric_fn = metric_name.__call__ metric_fn = metric_name.__call__
metric_name = metric_name.__name__ metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else: else:
self._metric_fn_list[metric_name] = get_metric( metric_fn = get_metric(
metric_name, hf_evaluate_metric metric_name, hf_evaluate_metric
) )
self._metric_fn_kwargs[metric_name] = kwargs
self._metric_fn_kwargs[metric_name] = kwargs
self._metric_fn_list[metric_name] = partial(metric_fn, **kwargs) if kwargs != {} else metric_fn
self.download(self.config.dataset_kwargs) self.download(self.config.dataset_kwargs)
self._training_docs = None self._training_docs = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment