Commit c746d1fb authored by lintangsutawika
Browse files

fixing metric for each output_type

parent a339ffd8
......@@ -485,36 +485,64 @@ class ConfigurableTask(Task):
self._metric_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
for metric_config in self._config.metric_list:
metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
if self._config.output_type != "greedy_util":
eval_logger.warning(
f"Output Type set as {self._config.output_type} which does not use metric_list"
"metric list will be unused."
)
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
if self._config.output_type == "loglikelihood":
metric_list = ["perplexity", "acc"]
elif self._config.output_type == "loglikelihood_rolling":
metric_list = [
"word_perplexity",
"byte_perplexity",
"bits_per_byte",
]
elif self._config.output_type == "multiple_choice":
metric_list = ["acc", "acc_norm"]
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
for metric_name in metric_list:
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY["mean"]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
self._higher_is_better[metric_name] = higher_is_better
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(metric_name),
"Please check https://huggingface.co/evaluate-metric",
)
else:
for metric_config in self._config.metric_list:
metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[
aggregation
]
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
self._higher_is_better[metric_name] = higher_is_better
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(
metric_name
),
"Please check https://huggingface.co/evaluate-metric",
)
self.download(self._config.dataset_kwargs)
self._training_docs = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment