Commit c746d1fb authored by lintangsutawika's avatar lintangsutawika
Browse files

fixing metric for each output_type

parent a339ffd8
...@@ -485,6 +485,30 @@ class ConfigurableTask(Task): ...@@ -485,6 +485,30 @@ class ConfigurableTask(Task):
self._metric_kwargs = {} self._metric_kwargs = {}
self._aggregation_list = {} self._aggregation_list = {}
self._higher_is_better = {} self._higher_is_better = {}
if self._config.output_type != "greedy_util":
eval_logger.warning(
f"Output Type set as {self._config.output_type} which does not use metric_list"
"metric list will be unused."
)
if self._config.output_type == "loglikelihood":
metric_list = ["perplexity", "acc"]
elif self._config.output_type == "loglikelihood_rolling":
metric_list = [
"word_perplexity",
"byte_perplexity",
"bits_per_byte",
]
elif self._config.output_type == "multiple_choice":
metric_list = ["acc", "acc_norm"]
for metric_name in metric_list:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY["mean"]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
for metric_config in self._config.metric_list: for metric_config in self._config.metric_list:
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
...@@ -496,7 +520,9 @@ class ConfigurableTask(Task): ...@@ -496,7 +520,9 @@ class ConfigurableTask(Task):
if key not in ["metric", "aggregation", "higher_is_better"] if key not in ["metric", "aggregation", "higher_is_better"]
} }
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation] self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
aggregation
]
if metric_name in METRIC_REGISTRY.keys(): if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name] self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
...@@ -512,7 +538,9 @@ class ConfigurableTask(Task): ...@@ -512,7 +538,9 @@ class ConfigurableTask(Task):
except Exception: except Exception:
raise Warning( raise Warning(
"{} not found in the evaluate library!".format(metric_name), "{} not found in the evaluate library!".format(
metric_name
),
"Please check https://huggingface.co/evaluate-metric", "Please check https://huggingface.co/evaluate-metric",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment