Commit 94a49f70 authored by lintangsutawika's avatar lintangsutawika
Browse files

doc_to_target and doc_to_text should also accept int

parent 248b45da
...@@ -43,7 +43,6 @@ ALL_OUTPUT_TYPES = [ ...@@ -43,7 +43,6 @@ ALL_OUTPUT_TYPES = [
"multiple_choice", "multiple_choice",
"loglikelihood_rolling", "loglikelihood_rolling",
"greedy_until", "greedy_until",
"winograd_schema",
] ]
...@@ -735,7 +734,9 @@ class ConfigurableTask(Task): ...@@ -735,7 +734,9 @@ class ConfigurableTask(Task):
else: else:
doc_to_text = self._config.doc_to_text doc_to_text = self._config.doc_to_text
if type(doc_to_text) == str: if type(doc_to_text) == int:
return doc_to_text
elif type(doc_to_text) == str:
if doc_to_text in self.features: if doc_to_text in self.features:
# if self._config.doc_to_choice is not None: # if self._config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]] # return self.doc_to_choice(doc)[doc[doc_to_text]]
...@@ -763,7 +764,9 @@ class ConfigurableTask(Task): ...@@ -763,7 +764,9 @@ class ConfigurableTask(Task):
else: else:
doc_to_target = self._config.doc_to_target doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str: if type(doc_to_target) == int:
return doc_to_target
elif type(doc_to_target) == str:
if doc_to_target in self.features: if doc_to_target in self.features:
# if self._config.doc_to_choice is not None: # if self._config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]] # return self.doc_to_choice(doc)[doc[doc_to_target]]
...@@ -877,26 +880,6 @@ class ConfigurableTask(Task): ...@@ -877,26 +880,6 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
arguments = (ctx, self._config.generation_kwargs) arguments = (ctx, self._config.generation_kwargs)
elif self.OUTPUT_TYPE == "winograd_schema":
# similar to multiple_choice task type except each request contains
# multiple differing contexts with the same continuation
contexts = self.doc_to_choice(doc)
choice = self.doc_to_target(doc)
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(context, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, context in enumerate(contexts)
]
return request_list
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
) )
...@@ -991,21 +974,6 @@ class ConfigurableTask(Task): ...@@ -991,21 +974,6 @@ class ConfigurableTask(Task):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "winograd_schema":
lls, is_greedy = zip(*results)
if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls)
acc = 1.0 if np.argmax(lls) == gold else 0.0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "greedy_until": elif self.OUTPUT_TYPE == "greedy_until":
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
...@@ -1021,7 +989,7 @@ class ConfigurableTask(Task): ...@@ -1021,7 +989,7 @@ class ConfigurableTask(Task):
else: else:
raise ValueError( raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice' or 'winograd_schema' ", "'loglikelihood', 'loglikelihood_rolling', 'greedy_until' or 'multiple_choice'",
) )
return result_dict return result_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment