Commit 028f04c7 authored by lintangsutawika
Browse files

loglikelihood and loglikelihood rolling modified

parent 1d262a59
...@@ -566,11 +566,16 @@ class ConfigurableTask(Task): ...@@ -566,11 +566,16 @@ class ConfigurableTask(Task):
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type] _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list: for metric_name in _metric_list:
self._metric_fn_list[metric_name] = get_metric(metric_name) metric = get_metric(metric_name)
self._metric_fn_list[metric_name] = metric
self._metric_fn_kwargs[metric_name] = {} self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation( self._aggregation_list[metric_name] = metric.aggregation
metric_name # try:
) # self._aggregation_list[metric_name] = metric.aggregation
# except:
# self._aggregation_list[metric_name] = get_metric_aggregation(
# metric_name
# )
self._higher_is_better[metric_name] = is_higher_better(metric_name) self._higher_is_better[metric_name] = is_higher_better(metric_name)
else: else:
for metric_config in self.config.metric_list: for metric_config in self.config.metric_list:
...@@ -601,35 +606,35 @@ class ConfigurableTask(Task): ...@@ -601,35 +606,35 @@ class ConfigurableTask(Task):
) )
self._metric_fn_kwargs[metric_name] = kwargs self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config: # if "aggregation" in metric_config:
agg_name = metric_config["aggregation"] # agg_name = metric_config["aggregation"]
if type(agg_name) == str: # if type(agg_name) == str:
self._aggregation_list[metric_name] = get_aggregation(agg_name) # self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name): # elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config[ # self._aggregation_list[metric_name] = metric_config[
"aggregation" # "aggregation"
] # ]
else: # else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} # INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name) # metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning( # eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. " # f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default " # f"using default "
f"aggregation={INV_AGG_REGISTRY[metric_agg]}" # f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
) # )
self._aggregation_list[metric_name] = metric_agg # self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config: # if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config[ # self._higher_is_better[metric_name] = metric_config[
"higher_is_better" # "higher_is_better"
] # ]
else: # else:
eval_logger.warning( # eval_logger.warning(
f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. " # f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default " # f"using default "
f"higher_is_better={is_higher_better(metric_name)}" # f"higher_is_better={is_higher_better(metric_name)}"
) # )
self._higher_is_better[metric_name] = is_higher_better(metric_name) # self._higher_is_better[metric_name] = is_higher_better(metric_name)
self.download(self.config.dataset_kwargs) self.download(self.config.dataset_kwargs)
self._training_docs = None self._training_docs = None
...@@ -1022,35 +1027,15 @@ class ConfigurableTask(Task): ...@@ -1022,35 +1027,15 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results) return self.config.process_results(doc, results)
result_dict = {} result_dict = {}
use_metric = list(self._metric_fn_list.keys())
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
results = results[0] results = results[0]
ll, is_greedy = results ll, is_greedy = results
return { return ll, is_greedy
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results (loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc)) _words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc)) _bytes = self.count_bytes(self.doc_to_target(doc))
return { return loglikelihood, _words, _bytes
**(
{"word_perplexity": (loglikelihood, _words)}
if "word_perplexity" in use_metric
else {}
),
**(
{"byte_perplexity": (loglikelihood, _bytes)}
if "byte_perplexity" in use_metric
else {}
),
**(
{"bits_per_byte": (loglikelihood, _bytes)}
if "bits_per_byte" in use_metric
else {}
),
}
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
...@@ -1192,7 +1177,8 @@ class ConfigurableTask(Task): ...@@ -1192,7 +1177,8 @@ class ConfigurableTask(Task):
return result_dict return result_dict
def aggregation(self): def aggregation(self):
return self._aggregation_list # return self._aggregation_list
return self._metric_fn_list
def higher_is_better(self): def higher_is_better(self):
return self._higher_is_better return self._higher_is_better
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment