Commit b1b5239d authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

address PR comments

parent 6e3ef5ff
......@@ -80,7 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"winograd_schema": ["acc", "acc_norm"],
"winograd_schema": ["acc"],
"greedy_until": ["exact_match"],
}
......
......@@ -305,6 +305,8 @@ class Task(abc.ABC):
self._config.template_aliases + "{{answer_choices}}", doc
)
)
elif type(self._config.create_choices) == str:
return utils.apply_template(self._config.create_choices, doc)
else:
return self._config.create_choices(doc)
......@@ -813,26 +815,7 @@ class ConfigurableTask(Task):
)
for i, context in enumerate(contexts)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
]
)
return request_list
return Instance(
......@@ -933,9 +916,6 @@ class ConfigurableTask(Task):
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "greedy_until":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment