Commit 028f04c7 authored by lintangsutawika's avatar lintangsutawika
Browse files

loglikelihood and loglikelihood rolling modified

parent 1d262a59
......@@ -566,11 +566,16 @@ class ConfigurableTask(Task):
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name)
metric = get_metric(metric_name)
self._metric_fn_list[metric_name] = metric
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(
metric_name
)
self._aggregation_list[metric_name] = metric.aggregation
# try:
# self._aggregation_list[metric_name] = metric.aggregation
# except:
# self._aggregation_list[metric_name] = get_metric_aggregation(
# metric_name
# )
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
......@@ -601,35 +606,35 @@ class ConfigurableTask(Task):
)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if type(agg_name) == str:
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
)
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[
"higher_is_better"
]
else:
eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
# if "aggregation" in metric_config:
# agg_name = metric_config["aggregation"]
# if type(agg_name) == str:
# self._aggregation_list[metric_name] = get_aggregation(agg_name)
# elif callable(agg_name):
# self._aggregation_list[metric_name] = metric_config[
# "aggregation"
# ]
# else:
# INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
# metric_agg = get_metric_aggregation(metric_name)
# eval_logger.warning(
# f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
# f"using default "
# f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
# )
# self._aggregation_list[metric_name] = metric_agg
# if "higher_is_better" in metric_config:
# self._higher_is_better[metric_name] = metric_config[
# "higher_is_better"
# ]
# else:
# eval_logger.warning(
# f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
# f"using default "
# f"higher_is_better={is_higher_better(metric_name)}"
# )
# self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.download(self.config.dataset_kwargs)
self._training_docs = None
......@@ -1022,35 +1027,15 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
return {
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
return ll, is_greedy
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
return {
**(
{"word_perplexity": (loglikelihood, _words)}
if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
}
return loglikelihood, _words, _bytes
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
......@@ -1192,7 +1177,8 @@ class ConfigurableTask(Task):
return result_dict
def aggregation(self):
return self._aggregation_list
# return self._aggregation_list
return self._metric_fn_list
def higher_is_better(self):
return self._higher_is_better
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment