Commit 703e0d55 authored by lintangsutawika's avatar lintangsutawika
Browse files

adjusted aggregation config

parent 787b23f6
...@@ -577,22 +577,25 @@ class ConfigurableTask(Task): ...@@ -577,22 +577,25 @@ class ConfigurableTask(Task):
key: metric_config[key] key: metric_config[key]
for key in metric_config for key in metric_config
if key if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"] not in ["metric", "aggregation", "higher_is_better", "use_hf_evaluate"]
} }
use_hf_evaluate = ( use_hf_evaluate = (
"hf_evaluate" in metric_config "use_hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True and metric_config["use_hf_evaluate"] is True
) )
if callable(metric_name): if self.config.process_results is not None:
metric_fn = None
kwargs = {}
elif callable(metric_name):
metric_fn = metric_name.__call__ metric_fn = metric_name.__call__
metric_name = metric_name.__name__ metric_name = metric_name.__name__
else: else:
assert type(metric_name) == str assert type(metric_name) == str
use_metric_for_agg = True
if use_hf_evaluate: if use_hf_evaluate:
metric_fn = get_evaluate(metric_name, **kwargs) metric_fn = get_evaluate(metric_name, **kwargs)
elif metric_name in METRIC_REGISTRY: elif metric_name in METRIC_REGISTRY:
from_registry = True
metric = get_metric(metric_name, **kwargs) metric = get_metric(metric_name, **kwargs)
metric_fn = metric["function"] metric_fn = metric["function"]
...@@ -606,19 +609,17 @@ class ConfigurableTask(Task): ...@@ -606,19 +609,17 @@ class ConfigurableTask(Task):
elif callable(agg_name): # noqa: E721 elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = agg_name self._aggregation_list[metric_name] = agg_name
else: else:
if from_registry: if use_hf_evaluate:
if "aggregation" in metric: self._aggregation_list[metric_name] = metric_fn
self._aggregation_list[metric_name] = metric["aggregation"] elif (metric_name in METRIC_REGISTRY) and ("aggregation" in metric):
else: self._aggregation_list[metric_name] = metric["aggregation"]
self._aggregation_list[metric_name] = metric_fn
if "higher_is_better" in metric_config: if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[ self._higher_is_better[metric_name] = metric_config[
"higher_is_better" "higher_is_better"
] ]
else: else:
if from_registry: self._higher_is_better[metric_name] = metric["higher_is_better"]
self._higher_is_better[metric_name] = metric["higher_is_better"]
self.download(self.config.dataset_kwargs) self.download(self.config.dataset_kwargs)
self._training_docs = None self._training_docs = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment