Unverified commit 3aeb95b6 authored by Lintang Sutawika, committed by GitHub

Very rough draft and scribbles

parent 884ba785
@@ -628,6 +628,16 @@ class ConfigurableTask(Task):
                 list(self.fewshot_docs()), self, rnd=random.Random()
             )
 
+        # Test One Doc: probe a single sample doc to detect list-valued outputs.
+        # (Editor's fix of draft scribbles: the probe needs a doc argument, the
+        # bare `else` needs a colon, and `doc_choice` reads as `doc_to_choice`.)
+        test_doc = list(self.fewshot_docs())[0]
+        text_output = self.doc_to_text(test_doc)
+        if type(text_output) is list:
+            self.multiple_input = True
+        else:
+            self.multiple_input = False
+        doc_to_target_output = self.doc_to_target(test_doc)
+        doc_to_choice_output = self.doc_to_choice(test_doc)
+
     def download(self, dataset_kwargs=None):
         self.dataset = datasets.load_dataset(
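For readers skimming the diff: the new block above probes one document at task setup and flags the task as `multiple_input` when `doc_to_text` yields a list (one context per answer choice) rather than a single string. A minimal self-contained sketch of that detection, where `probe_multiple_input` and the toy docs are hypothetical, not harness API:

```python
# Sketch only: `probe_multiple_input` is a hypothetical helper,
# not part of lm-evaluation-harness.
from typing import Callable, List, Union

def probe_multiple_input(
    doc_to_text: Callable[[dict], Union[str, List[str]]], sample_doc: dict
) -> bool:
    # A list-valued rendering means each answer choice gets its own context,
    # e.g. Winogrande-style tasks where the premise varies per option.
    return isinstance(doc_to_text(sample_doc), list)

# Single-context task: one prompt string, choices are continuations.
single = lambda doc: f"Question: {doc['question']}\nAnswer:"
# Multi-context task: one prompt per choice, continuation is shared.
multi = lambda doc: [doc["sentence"].replace("_", opt) for opt in doc["options"]]

doc = {"question": "2+2?", "sentence": "The _ barked.", "options": ["dog", "cat"]}
print(probe_multiple_input(single, doc))  # False
print(probe_multiple_input(multi, doc))   # True
```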
@@ -707,7 +717,8 @@ class ConfigurableTask(Task):
             return utils.apply_template(doc_to_text, doc)
         elif callable(doc_to_text):
             return doc_to_text(doc)
-        if hasattr(doc_to_text, "apply"):
+        # Used when applying a Promptsource template
+        elif hasattr(doc_to_text, "apply"):
             return doc_to_text.apply(doc)[0]
         else:
             print(type(doc_to_text))
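The dispatch above tries, in order: a string template, a plain callable, and finally any object exposing an `.apply()` method (duck-typed Promptsource templates). A rough standalone illustration of the same ordering, with hypothetical names (`render_text` is not the harness's actual implementation, and `str.format` stands in for `utils.apply_template`):

```python
# Hypothetical re-creation of the doc_to_text dispatch order.
def render_text(doc_to_text, doc):
    if isinstance(doc_to_text, str):
        # simple str.format template stands in for utils.apply_template
        return doc_to_text.format(**doc)
    elif callable(doc_to_text):
        return doc_to_text(doc)
    elif hasattr(doc_to_text, "apply"):
        # Promptsource templates render to [input, target]; index 0 is the input
        return doc_to_text.apply(doc)[0]
    raise TypeError(f"unsupported doc_to_text type: {type(doc_to_text)}")

print(render_text("Q: {question}\nA:", {"question": "2+2?"}))
```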
@@ -724,6 +735,7 @@ class ConfigurableTask(Task):
             return utils.apply_template(doc_to_target, doc)
         elif callable(doc_to_target):
             return doc_to_target(doc)
+        # Used when applying a Promptsource template
         elif hasattr(doc_to_target, "apply"):
             return doc_to_target.apply(doc)[1]
         else:
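Note the complementary indices: `doc_to_text` takes element `[0]` of the Promptsource `apply()` output and `doc_to_target` takes element `[1]`, since a Promptsource template renders a document into an `[input, target]` pair. A stub makes the convention concrete (`StubTemplate` is illustrative only; real templates live in promptsource YAML):

```python
# Illustrative stand-in for a Promptsource template object.
class StubTemplate:
    def apply(self, doc):
        # real templates split the rendered prompt text into [input, target]
        return [f"Premise: {doc['premise']} True or false?", doc["label"]]

template = StubTemplate()
doc = {"premise": "Water is wet.", "label": "true"}
print(template.apply(doc)[0])  # what doc_to_text returns: the input
print(template.apply(doc)[1])  # what doc_to_target returns: the target
```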
@@ -757,9 +769,13 @@ class ConfigurableTask(Task):
             arguments = (self.doc_to_target(doc),)
         elif self.OUTPUT_TYPE == "multiple_choice":
             # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
             # TODO: any cleaner way to do this?
-            choices = self.create_choices(doc)
+            if self.multiple_input:
+                choices = self.doc_to_text(doc)
+                continuation = self.doc_to_target(doc)
+            else:
+                continuation = self.create_choices(doc)
             request_list = [
                 Instance(
                     request_type="loglikelihood",
@@ -857,13 +873,11 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls, is_greedy = zip(*results)
 
-            if self._config.gold_alias is not None:
-                gold = int(self.gold_alias(doc))
-            else:
-                gold = int(self.doc_to_target(doc))
-
             # retrieve choices in List[str] form, to compute choice lengths, etc.
             choices = self.create_choices(doc)
+            completion_len = np.array([float(len(i)) for i in choices])
 
             if (
                 2 * len(choices) == len(lls)
                 and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -875,11 +889,17 @@ class ConfigurableTask(Task):
                 # and this stores our "regular" conditional loglikelihoods
                 lls = lls[::2]
 
-            pred = np.argmax(lls)
+            if self._config.gold_alias is not None:
+                gold = int(self.gold_alias(doc))
+                pred = np.argmax(lls)
+                pred_norm = np.argmax(lls / completion_len)
+            else:
+                gold = self.doc_to_target(doc)
+                pred = choices[np.argmax(lls)]
+                pred_norm = choices[np.argmax(lls / completion_len)]
 
-            acc = 1.0 if np.argmax(lls) == gold else 0.0
-            completion_len = np.array([float(len(i)) for i in choices])
-            acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
+            acc = 1.0 if pred == gold else 0.0
+            acc_norm = 1.0 if pred_norm == gold else 0.0
 
             result_dict = {
                 **({"acc": acc} if "acc" in use_metric else {}),
...
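The net effect of the last two hunks: `completion_len` is now computed up front, and the prediction is compared to gold either as an integer index (when `gold_alias` supplies a label) or as a choice string (when `doc_to_target` returns the answer text). Length-normalized accuracy divides each loglikelihood by the character length of its choice before taking the argmax. A self-contained numeric sketch with toy loglikelihoods:

```python
import numpy as np

# Toy values standing in for per-choice loglikelihoods from a model.
choices = ["Paris", "Lyon", "Marseille"]
lls = np.array([-4.2, -5.1, -4.5])
completion_len = np.array([float(len(c)) for c in choices])

gold = "Paris"  # string gold, as in the new doc_to_target branch
pred = choices[np.argmax(lls)]                        # raw argmax -> "Paris"
pred_norm = choices[np.argmax(lls / completion_len)]  # normalized -> "Marseille"

acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
print(acc, acc_norm)  # 1.0 0.0
```

Note how normalization can flip the prediction: long choices pay a smaller per-character penalty, which is exactly why `acc` and `acc_norm` are reported separately.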