Commit 703e0d55 authored by lintangsutawika's avatar lintangsutawika
Browse files

adjusted aggregation config

parent 787b23f6
......@@ -577,22 +577,25 @@ class ConfigurableTask(Task):
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
not in ["metric", "aggregation", "higher_is_better", "use_hf_evaluate"]
}
use_hf_evaluate = (
"hf_evaluate" in metric_config
and metric_config["hf_evaluate"] is True
"use_hf_evaluate" in metric_config
and metric_config["use_hf_evaluate"] is True
)
if callable(metric_name):
if self.config.process_results is not None:
metric_fn = None
kwargs = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
else:
assert type(metric_name) == str
use_metric_for_agg = True
if use_hf_evaluate:
metric_fn = get_evaluate(metric_name, **kwargs)
elif metric_name in METRIC_REGISTRY:
from_registry = True
metric = get_metric(metric_name, **kwargs)
metric_fn = metric["function"]
......@@ -606,19 +609,17 @@ class ConfigurableTask(Task):
elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = agg_name
else:
if from_registry:
if "aggregation" in metric:
self._aggregation_list[metric_name] = metric["aggregation"]
else:
self._aggregation_list[metric_name] = metric_fn
if use_hf_evaluate:
self._aggregation_list[metric_name] = metric_fn
elif (metric_name in METRIC_REGISTRY) and ("aggregation" in metric):
self._aggregation_list[metric_name] = metric["aggregation"]
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
if from_registry:
self._higher_is_better[metric_name] = metric["higher_is_better"]
self._higher_is_better[metric_name] = metric["higher_is_better"]
self.download(self.config.dataset_kwargs)
self._training_docs = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment