Commit a2e41158 authored by lintangsutawika
Browse files

fixed acc_mutual_info calc position

parent 424a4280
......@@ -500,10 +500,7 @@ class ConfigurableTask(Task):
_metric_list = DEFAULT_METRIC_REGISTRY[self._config.output_type]
if self._config.metric_list is None:
eval_logger.warning(
f"Output Type set as {self._config.output_type} and metric_list is not set"
"Will default to exact_match"
)
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
self._aggregation_list[metric_name] = DEFAULT_AGGREGATION_REGISTRY[
......@@ -799,6 +796,16 @@ class ConfigurableTask(Task):
self._config.template_aliases + "{{answer_choices}}", doc
)
)
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_list.keys()
):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
acc = 1.0 if np.argmax(lls) == gold else 0.0
completion_len = np.array([float(len(i)) for i in choices])
......@@ -817,14 +824,6 @@ class ConfigurableTask(Task):
result_dict["exact_match"] = int(is_greedy)
if "acc_mutual_info" in use_metric:
if 2 * len(choices) == len(lls):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment