Commit 2a573a19 authored by lintangsutawika's avatar lintangsutawika
Browse files

adjust to be backwards compatible

parent 703e0d55
...@@ -571,13 +571,17 @@ class ConfigurableTask(Task): ...@@ -571,13 +571,17 @@ class ConfigurableTask(Task):
else: else:
for metric_config in self.config.metric_list: for metric_config in self.config.metric_list:
assert "metric" in metric_config assert "metric" in metric_config
from_registry = False
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
kwargs = { kwargs = {
key: metric_config[key] key: metric_config[key]
for key in metric_config for key in metric_config
if key if key
not in ["metric", "aggregation", "higher_is_better", "use_hf_evaluate"] not in [
"metric",
"aggregation",
"higher_is_better",
"use_hf_evaluate",
]
} }
use_hf_evaluate = ( use_hf_evaluate = (
"use_hf_evaluate" in metric_config "use_hf_evaluate" in metric_config
...@@ -592,7 +596,6 @@ class ConfigurableTask(Task): ...@@ -592,7 +596,6 @@ class ConfigurableTask(Task):
metric_name = metric_name.__name__ metric_name = metric_name.__name__
else: else:
assert type(metric_name) == str assert type(metric_name) == str
use_metric_for_agg = True
if use_hf_evaluate: if use_hf_evaluate:
metric_fn = get_evaluate(metric_name, **kwargs) metric_fn = get_evaluate(metric_name, **kwargs)
elif metric_name in METRIC_REGISTRY: elif metric_name in METRIC_REGISTRY:
...@@ -602,16 +605,27 @@ class ConfigurableTask(Task): ...@@ -602,16 +605,27 @@ class ConfigurableTask(Task):
self._metric_fn_kwargs[metric_name] = kwargs self._metric_fn_kwargs[metric_name] = kwargs
self._metric_fn_list[metric_name] = metric_fn self._metric_fn_list[metric_name] = metric_fn
# Ignores aggregation if the metric set
# is a registered metric
# for backward compatibility
if metric_name in METRIC_REGISTRY and ("aggregation" not in metric):
self._aggregation_list[metric_name] = metric_fn
else:
if "aggregation" in metric_config: if "aggregation" in metric_config:
agg_name = metric_config["aggregation"] agg_name = metric_config["aggregation"]
if isinstance(agg_name, str): if isinstance(agg_name, str):
self._aggregation_list[metric_name] = get_aggregation(agg_name) self._aggregation_list[metric_name] = get_aggregation(
agg_name
)
elif callable(agg_name): # noqa: E721 elif callable(agg_name): # noqa: E721
self._aggregation_list[metric_name] = agg_name self._aggregation_list[metric_name] = agg_name
else: else:
if use_hf_evaluate: if use_hf_evaluate:
self._aggregation_list[metric_name] = metric_fn self._aggregation_list[metric_name] = metric_fn
elif (metric_name in METRIC_REGISTRY) and ("aggregation" in metric): elif (metric_name in METRIC_REGISTRY) and (
"aggregation" in metric
):
self._aggregation_list[metric_name] = metric["aggregation"] self._aggregation_list[metric_name] = metric["aggregation"]
if "higher_is_better" in metric_config: if "higher_is_better" in metric_config:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment