Unverified commit 3aeb95b6 authored by Lintang Sutawika, committed by GitHub

Very rough draft and scribbles

parent 884ba785
@@ -628,6 +628,16 @@ class ConfigurableTask(Task):
list(self.fewshot_docs()), self, rnd=random.Random()
)
# Test one doc to check whether doc_to_text returns a list of
# inputs (a multiple-input task) or a single prompt string.
# Assumes the task defines test docs.
test_doc = list(self.test_docs())[0]
text_output = self.doc_to_text(test_doc)
if type(text_output) is list:
    self.multiple_input = True
else:
    self.multiple_input = False
doc_to_target_output = self.doc_to_target(test_doc)
doc_to_choice_output = self.create_choices(test_doc)

def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset(
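The probe in the first hunk drives the new `multiple_input` flag: when `doc_to_text` renders a single doc into a *list* of prompts, the task is treated as multiple-input, i.e. several contexts scored against one shared continuation (Winogrande-style). A minimal, self-contained sketch of that idea; the helper and the sample fields are hypothetical, not part of the harness:

```python
# Sketch only: detecting a multiple-input task from one rendered doc.
def detect_multiple_input(doc_to_text, sample_doc):
    """True if doc_to_text yields a list of prompts for a single doc."""
    return isinstance(doc_to_text(sample_doc), list)

# A Winogrande-style doc: two candidate contexts, one continuation.
sample = {
    "sentence": "The trophy didn't fit in the suitcase because _ was too big.",
    "options": ["the trophy", "the suitcase"],
}
render = lambda d: [d["sentence"].replace("_", o) for o in d["options"]]
assert detect_multiple_input(render, sample)  # -> True
```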
@@ -707,7 +717,8 @@ class ConfigurableTask(Task):
return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text):
return doc_to_text(doc)
# Used when applying a Promptsource template
elif hasattr(doc_to_text, "apply"):
return doc_to_text.apply(doc)[0]
else:
print(type(doc_to_text))
@@ -724,6 +735,7 @@ class ConfigurableTask(Task):
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
# Used when applying a Promptsource template
elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
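Both `doc_to_text` and `doc_to_target` now share the same dispatch order: a template string, then a callable, then any object exposing `.apply`, which is how Promptsource templates are recognized. `apply(doc)` yields an `[input, target]` pair, hence index `[0]` in `doc_to_text` and `[1]` in `doc_to_target`. A sketch of that dispatch with a stand-in template object (simplified: `str.format` stands in for `utils.apply_template`):

```python
# Stand-in for a Promptsource template: anything with
# .apply(doc) -> [input_text, target_text] is accepted.
class FakeTemplate:
    def apply(self, doc):
        return [f"Question: {doc['q']}\nAnswer:", f" {doc['a']}"]

def render_text(doc_to_text, doc):
    if isinstance(doc_to_text, str):
        return doc_to_text.format(**doc)   # template string
    elif callable(doc_to_text):
        return doc_to_text(doc)            # user-supplied function
    elif hasattr(doc_to_text, "apply"):    # Promptsource-style template
        return doc_to_text.apply(doc)[0]   # [0] = input half of the pair

doc = {"q": "What is 2 + 2?", "a": "4"}
print(render_text(FakeTemplate(), doc))    # Question: What is 2 + 2? ...
```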
@@ -757,9 +769,13 @@ class ConfigurableTask(Task):
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
# we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
# TODO: any cleaner way to do this?
if self.multiple_input:
    choices = self.doc_to_text(doc)
    continuation = self.doc_to_target(doc)
else:
    continuation = self.create_choices(doc)
request_list = [
Instance(
request_type="loglikelihood",
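The new branch flips the roles of context and continuation for multiple-input tasks: the list from `doc_to_text` supplies the varying side while the single target stays fixed, whereas ordinary multiple choice keeps one context and varies the continuations. A rough sketch of the resulting (context, continuation) pairs sent for loglikelihood scoring; the names here are illustrative only:

```python
# Illustrative only: the pairs each branch would score.
def build_pairs(multiple_input, doc, doc_to_text, doc_to_target, create_choices):
    if multiple_input:
        # several contexts, one shared continuation
        return [(ctx, doc_to_target(doc)) for ctx in doc_to_text(doc)]
    else:
        # one context, several candidate continuations
        return [(doc_to_text(doc), choice) for choice in create_choices(doc)]
```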
@@ -857,13 +873,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = self.create_choices(doc)
completion_len = np.array([float(len(i)) for i in choices])
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -875,11 +889,17 @@ class ConfigurableTask(Task):
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
if self._config.gold_alias is not None:
    gold = int(self.gold_alias(doc))
    pred = np.argmax(lls)
    pred_norm = np.argmax(lls / completion_len)
else:
    gold = self.doc_to_target(doc)
    pred = choices[np.argmax(lls)]
    pred_norm = choices[np.argmax(lls / completion_len)]
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
......
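A worked toy example of the rewritten scoring path for the non-`gold_alias` branch, where `gold`, `pred`, and `pred_norm` are now choice strings rather than indices (`gold_alias` keeps the index-based comparison). The numbers are invented to show why `acc` and `acc_norm` can disagree:

```python
import numpy as np

choices = ["Paris", "Montpellier"]
lls = np.array([-4.2, -9.1])  # conditional loglikelihood of each choice
completion_len = np.array([float(len(c)) for c in choices])  # [5., 11.]

gold = "Paris"  # what doc_to_target(doc) would return here
pred = choices[np.argmax(lls)]                        # "Paris"
pred_norm = choices[np.argmax(lls / completion_len)]  # "Montpellier", since
# -9.1 / 11 ≈ -0.83 beats -4.2 / 5 = -0.84: normalization flips the pick

acc = 1.0 if pred == gold else 0.0            # 1.0
acc_norm = 1.0 if pred_norm == gold else 0.0  # 0.0
```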