"magic_pdf/vscode:/vscode.git/clone" did not exist on "c7a685b302ceb161e1e60813457575f42d7a66ab"
Commit a8396b2c authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

add new task type: winograd_schema

parent d674c7bd
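
Background for the diff below: unlike `multiple_choice`, where several continuations are scored against one context, a Winograd-schema item scores the *same* continuation against several contexts, one per candidate antecedent, and predicts the candidate whose context makes the continuation most likely. A toy illustration of the scheme (the sentence and scores are made up; `loglikelihood` stands in for a real model call):

```python
import numpy as np

# One context per candidate antecedent; the continuation is shared.
contexts = [
    "The trophy doesn't fit in the suitcase because the trophy",
    "The trophy doesn't fit in the suitcase because the suitcase",
]
continuation = " is too big."

def loglikelihood(context: str, continuation: str) -> float:
    """Stand-in for an LM call returning log P(continuation | context)."""
    fake_scores = {contexts[0]: -2.1, contexts[1]: -4.7}  # invented numbers
    return fake_scores[context]

lls = [loglikelihood(ctx, continuation) for ctx in contexts]
pred = int(np.argmax(lls))  # 0 -> "the trophy" is predicted as the antecedent
```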
@@ -80,6 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
+    "winograd_schema": ["acc", "acc_norm"],
     "greedy_until": ["exact_match"],
 }
@@ -44,6 +44,7 @@ ALL_OUTPUT_TYPES = [
     "multiple_choice",
     "loglikelihood_rolling",
     "greedy_until",
+    "winograd_schema"
 ]
@@ -75,6 +76,7 @@ class TaskConfig(dict):
     metric_list: str = None
     gold_alias: Union[Callable, str] = None
+    create_choices: Union[Callable, str] = None
     output_type: str = "greedy_until"
     generation_kwargs: dict = None
     filter_list: Union[str, list] = None
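
The new `create_choices` field accepts a callable (or a string resolved later). A hypothetical configuration using the callable path; the doc schema (`text`, `pronoun`, `options`) is invented for illustration:

```python
def wsc_contexts(doc):
    # build one context per candidate antecedent (hypothetical doc schema)
    return [doc["text"].replace(doc["pronoun"], opt) for opt in doc["options"]]

config = TaskConfig(
    output_type="winograd_schema",
    create_choices=wsc_contexts,
)
```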
@@ -295,6 +297,16 @@ class Task(abc.ABC):
             The processed version of the specified `doc`.
         """
         return doc
 
+    def create_choices(self, doc):
+        if self._config.create_choices is None:
+            return ast.literal_eval(
+                utils.apply_template(
+                    self._config.template_aliases + "{{answer_choices}}", doc
+                )
+            )
+        else:
+            return self._config.create_choices(doc)
+
     @property
     def instances(self):
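
When no callable is configured, the default branch above renders the task's `{{answer_choices}}` Jinja alias and parses the rendered string with `ast.literal_eval`. A minimal standalone sketch of that parse step (the rendered string is a made-up example of what `utils.apply_template` might return):

```python
import ast

rendered = "['the trophy', 'the suitcase']"  # hypothetical template output
choices = ast.literal_eval(rendered)         # safely parse the list literal
assert choices == ["the trophy", "the suitcase"]
```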
@@ -746,11 +758,8 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
             # TODO: any cleaner way to do this?
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             request_list = [
                 Instance(
                     request_type="loglikelihood",
@@ -786,6 +795,45 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "greedy_until":
             arguments = (ctx, self._config.generation_kwargs)
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            # similar to the multiple_choice task type, except each request pairs
+            # a different context with the same continuation
+            contexts = self.create_choices(doc)
+            choice = self.doc_to_target(doc)
+            request_list = [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(context, " {}".format(choice)),
+                    idx=i,
+                    **kwargs,
+                )
+                for i, context in enumerate(contexts)
+            ]
+            # TODO: we should raise a warning telling users this will increase runtime by up to ~2x.
+            if "acc_mutual_info" in self._metric_fn_list.keys():
+                # if we are calculating multiple choice accuracy using mutual
+                # information instead of raw loglikelihood as the metric, we need
+                # unconditional loglikelihoods. Here mutual info refers to
+                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)),
+                # i.e. normalizing by subtracting the unconditional logprob of the choice.
+                request_list.extend(
+                    [
+                        Instance(
+                            request_type="loglikelihood",
+                            doc=doc,
+                            arguments=("", "{}".format(choice)),
+                            idx=i,
+                            **kwargs,
+                        )
+                        # one unconditional request per context so result counts
+                        # line up downstream; the continuation is shared, so we
+                        # iterate over contexts (`choices` is undefined here)
+                        for i, _ in enumerate(contexts)
+                    ]
+                )
+            return request_list
+
         return Instance(
             request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
         )
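
One consequence of the shared continuation worth noting: for this output type the mutual-information correction log(P(choice|ctx)) - log(P(choice)) subtracts the *same* unconditional term from every candidate, so it cannot change the argmax (in `multiple_choice`, by contrast, each choice has its own unconditional logprob). A quick check of that claim:

```python
import numpy as np

lls = np.array([-2.1, -4.7])  # log P(choice | ctx_i), invented numbers
uncond = -3.0                 # log P(choice), identical for every context
assert np.argmax(lls) == np.argmax(lls - uncond)  # prediction is unchanged
```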
@@ -835,11 +883,7 @@ class ConfigurableTask(Task):
             pred = np.argmax(lls)
             # retrieve choices in List[str] form, to compute choice lengths, etc.
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             if (
                 2 * len(choices) == len(lls)
                 and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -875,6 +919,24 @@ class ConfigurableTask(Task):
                 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                 result_dict["acc_mutual_info"] = acc_mutual_info
 
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            lls, is_greedy = zip(*results)
+            if self._config.gold_alias is not None:
+                gold = int(self.gold_alias(doc))
+            else:
+                gold = int(self.doc_to_target(doc))
+            pred = np.argmax(lls)
+            acc = 1.0 if pred == gold else 0.0
+            # the continuation is identical for every context, so length
+            # normalization cannot change the argmax; acc_norm reduces to acc
+            acc_norm = acc
+            result_dict = {
+                **({"acc": acc} if "acc" in use_metric else {}),
+                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
+                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
+            }
+
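
The `f1` and `mcc` entries store per-document `(gold, pred)` pairs rather than scalars, leaving the reduction to a corpus-level aggregation function. A hedged sketch of how such pairs are typically reduced (the use of scikit-learn here is an assumption, not necessarily what the harness does):

```python
from sklearn.metrics import f1_score, matthews_corrcoef

items = [(0, 0), (1, 0), (1, 1), (0, 0)]  # made-up (gold, pred) pairs
golds, preds = zip(*items)
corpus_f1 = f1_score(golds, preds)            # corpus-level F1
corpus_mcc = matthews_corrcoef(golds, preds)  # Matthews correlation
```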
@@ -893,7 +955,7 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "greedy_until":
             if self._config.gold_alias is not None:
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice', or 'winograd_schema'",
             )
         return result_dict
@@ -214,7 +214,7 @@ def evaluate(
         # aggregate Instances by LM method requested to get output.
         reqtype = (
             "loglikelihood"
-            if task.OUTPUT_TYPE == "multiple_choice"
+            if (task.OUTPUT_TYPE == "multiple_choice" or task.OUTPUT_TYPE == "winograd_schema")
             else task.OUTPUT_TYPE
         )  # TODO: this is hacky, fix in task.py
         requests[reqtype].extend(task.instances)
@@ -274,7 +274,6 @@ def evaluate(
                 enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
             )
         )
-
         for doc_id, doc in doc_iterator:
             # subset instances to only this document id ; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))