"driver/include/conv_common.hpp" did not exist on "84d9802d30de16795e63a8625098634527c80ae4"
Commit a8396b2c authored by Benjamin Fattori

add new task type: winograd_schema

parent d674c7bd
@@ -80,6 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
],
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"winograd_schema": ["acc", "acc_norm"],
"greedy_until": ["exact_match"],
}
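With this entry, a task configured with `output_type: winograd_schema` and no explicit `metric_list` falls back to the same accuracy metrics as `multiple_choice`. A minimal sketch of the lookup; the registry shape is taken from this diff, while the `default_metrics_for` helper is hypothetical:

```python
DEFAULT_METRIC_REGISTRY = {
    "multiple_choice": ["acc", "acc_norm"],
    "winograd_schema": ["acc", "acc_norm"],
    "greedy_until": ["exact_match"],
}

def default_metrics_for(output_type: str) -> list:
    # hypothetical helper: resolve a task's default metrics from its output type
    return DEFAULT_METRIC_REGISTRY[output_type]

assert default_metrics_for("winograd_schema") == ["acc", "acc_norm"]
```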
@@ -44,6 +44,7 @@ ALL_OUTPUT_TYPES = [
"multiple_choice",
"loglikelihood_rolling",
"greedy_until",
"winograd_schema"
]
@@ -75,6 +76,7 @@ class TaskConfig(dict):
metric_list: str = None
gold_alias: Union[Callable, str] = None
create_choices: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
filter_list: Union[str, list] = None
@@ -295,6 +297,16 @@ class Task(abc.ABC):
The processed version of the specified `doc`.
"""
return doc
def create_choices(self, doc):
if self._config.create_choices is None:
return ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
else:
return self._config.create_choices(doc)
@property
def instances(self):
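For reference, the template branch of `create_choices` renders the Jinja alias string and parses the printed result back into a Python list. A standalone sketch of that round trip, using `jinja2.Template` directly in place of the harness's `utils.apply_template`, with an invented doc and `template_aliases` value:

```python
import ast
from jinja2 import Template

# hypothetical doc and alias; real tasks define these in their configs
doc = {"option1": "the trophy", "option2": "the suitcase"}
template_aliases = "{% set answer_choices = [option1, option2] %}"

# rendering "{{answer_choices}}" prints the list's repr,
# "['the trophy', 'the suitcase']", which ast.literal_eval parses back
rendered = Template(template_aliases + "{{answer_choices}}").render(**doc)
choices = ast.literal_eval(rendered)
assert choices == ["the trophy", "the suitcase"]
```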
@@ -746,11 +758,8 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
# we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
# TODO: any cleaner way to do this?
choices = ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
choices = self.create_choices(doc)
request_list = [
Instance(
request_type="loglikelihood",
@@ -786,6 +795,45 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self._config.generation_kwargs)
elif self.OUTPUT_TYPE == "winograd_schema":
# similar to multiple_choice task type except each request contains
# multiple differing contexts with the same continuation
contexts = self.create_choices(doc)
choice = self.doc_to_target(doc)
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(context, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, context in enumerate(contexts)
]
            # TODO: we should raise a warning telling users this will increase runtime by up to ~2x.
if "acc_mutual_info" in self._metric_fn_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
                request_list.extend(
                    [
                        Instance(
                            request_type="loglikelihood",
                            doc=doc,
                            # the continuation is identical for every context,
                            # so each unconditional request scores the same string
                            arguments=("", "{}".format(choice)),
                            idx=i,
                            **kwargs,
                        )
                        for i, _ in enumerate(contexts)
                    ]
                )
return request_list
return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
)
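The new request type in miniature: `multiple_choice` holds the context fixed and varies the continuation, while `winograd_schema` varies the context and holds the continuation fixed, then asks under which context the shared continuation is most likely. A sketch with a hypothetical WinoGrande-style doc (field names invented for illustration):

```python
doc = {
    "contexts": [
        "The trophy doesn't fit in the suitcase because the trophy",
        "The trophy doesn't fit in the suitcase because the suitcase",
    ],
    "continuation": "is too big.",
    "label": 1,  # index of the correct context
}

# one loglikelihood request per candidate context, same continuation
# throughout, mirroring arguments=(context, " {}".format(choice)) above
arguments = [(ctx, " {}".format(doc["continuation"])) for ctx in doc["contexts"]]
for args in arguments:
    print(args)
```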
@@ -835,11 +883,7 @@ class ConfigurableTask(Task):
pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
choices = self.create_choices(doc)
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -875,6 +919,24 @@ class ConfigurableTask(Task):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "winograd_schema":
lls, is_greedy = zip(*results)
if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
acc = 1.0 if np.argmax(lls) == gold else 0.0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "greedy_until":
if self._config.gold_alias is not None:
@@ -893,7 +955,7 @@ class ConfigurableTask(Task):
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice' or 'winograd_schema' ",
)
return result_dict
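Putting the winograd_schema scoring together: the conditional loglikelihoods are compared by argmax against the gold index. Because every request shares the same continuation, both length normalization and the mutual-information correction log P(choice|ctx) - log P(choice) shift all scores by the same constant and cannot reorder them. A toy walk-through with invented numbers:

```python
import numpy as np

lls = np.array([-6.8, -4.2])  # log P(continuation | context_i), invented values
gold = 1

pred = int(np.argmax(lls))          # -> 1
acc = 1.0 if pred == gold else 0.0  # -> 1.0

# subtracting the shared unconditional log P(continuation) leaves
# the argmax (and therefore accuracy) unchanged
ll_unconditional = -5.0
assert np.argmax(lls - ll_unconditional) == np.argmax(lls)
```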
@@ -214,7 +214,7 @@ def evaluate(
# aggregate Instances by LM method requested to get output.
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
            if task.OUTPUT_TYPE in ["multiple_choice", "winograd_schema"]
else task.OUTPUT_TYPE
) # TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
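Since both output types are answered by loglikelihood queries, their instances land in the same request bucket. A rough sketch of the grouping, with `collections.defaultdict` standing in for the harness's `requests` mapping and `tasks` assumed to be the iterable from `evaluate()`:

```python
from collections import defaultdict

requests = defaultdict(list)
for task in tasks:  # assumption: task objects expose OUTPUT_TYPE and instances
    reqtype = (
        "loglikelihood"
        if task.OUTPUT_TYPE in ["multiple_choice", "winograd_schema"]
        else task.OUTPUT_TYPE
    )
    requests[reqtype].extend(task.instances)
```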
@@ -274,7 +274,6 @@ def evaluate(
enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
)
)
for doc_id, doc in doc_iterator:
# subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))