Commit 656c310e authored by lintangsutawika

Process update to handle doc_to_target variety

parent 3b50b941
@@ -322,6 +322,8 @@ class Task(abc.ABC):
                     self._config.template_aliases + "{{answer_choices}}", doc
                 )
             )
+        elif type(self._config.doc_to_choice) == dict:
+            return list(self._config.doc_to_choice.values())
         elif type(self._config.doc_to_choice) == str:
             return utils.apply_template(self._config.doc_to_choice, doc)
         else:
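
The new dict branch lets a task config list answer choices as a label-to-string mapping. A minimal sketch of that dispatch, with TaskConfigStub as an illustrative stand-in for the harness's real config object:

    # Sketch of the dict branch added above; TaskConfigStub is a
    # stand-in for the real task config class, not harness API.
    class TaskConfigStub:
        doc_to_choice = {0: "no", 1: "yes"}

    config = TaskConfigStub()
    if type(config.doc_to_choice) == dict:
        # Python dicts preserve insertion order (3.7+), so choice order
        # follows the order of the keys as written in the YAML mapping.
        choices = list(config.doc_to_choice.values())

    print(choices)  # ['no', 'yes']
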
@@ -645,16 +647,21 @@ class ConfigurableTask(Task):
         # Test One Doc
         test_doc = docs[0]
         self.features = list(test_doc.keys())
+        test_text = self.doc_to_text(test_doc)
+        test_target = self.doc_to_target(test_doc)
+        # test_choice = self.doc_to_choice(test_doc)
         if self._config.output_type == "multiple_choice":
             if type(test_text) is list:
                 self.multiple_input = len(test_text)
             elif type(test_text) is str:
                 self.multiple_input = 0
-            # test_choice = self.doc_choice(test_doc)
-            # test_target = self.doc_to_target(test_doc)
+            if type(test_target) is list:
+                self.multiple_output = len(test_target)
+            else:
+                self.multiple_output = 0

     def download(self, dataset_kwargs=None):
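
The constructor now probes one rendered document to learn whether the task produces a single string or a list of variants per doc. An illustrative, self-contained version of that probe (inputs are fabricated; detect_multiplicity is a hypothetical helper):

    # Mirrors the detection above: render one doc, record variant count.
    def detect_multiplicity(rendered):
        """Return N when rendering yields a list of N variants, else 0."""
        return len(rendered) if isinstance(rendered, list) else 0

    # A multiple-input task renders several contexts per doc...
    multiple_input = detect_multiplicity(["ctx A", "ctx B"])  # 2
    # ...while a single string means one context (or one target).
    multiple_output = detect_multiplicity("yes")              # 0
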
@@ -732,7 +739,10 @@ class ConfigurableTask(Task):
         doc_to_text = self._config.doc_to_text
         if type(doc_to_text) == str:
-            return utils.apply_template(doc_to_text, doc)
+            if doc_to_text in self.features:
+                return doc[doc_to_text]
+            else:
+                return utils.apply_template(doc_to_text, doc)
         elif callable(doc_to_text):
             return doc_to_text(doc)
         # Used when applying a Promptsource template
@@ -750,7 +760,10 @@ class ConfigurableTask(Task):
         doc_to_target = self._config.doc_to_target
         if type(doc_to_target) == str:
-            return utils.apply_template(doc_to_target, doc)
+            if doc_to_target in self.features:
+                return doc[doc_to_target]
+            else:
+                return utils.apply_template(doc_to_target, doc)
         elif callable(doc_to_target):
             return doc_to_target(doc)
         # Used when applying a Promptsource template
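
Both doc_to_text and doc_to_target now treat a plain string that names a dataset column as a raw lookup, falling back to template rendering otherwise. A hedged sketch of the dispatch, assuming utils.apply_template behaves like plain Jinja rendering (render_field is a hypothetical helper, not harness API):

    import jinja2  # assumption: apply_template is Jinja-like

    def render_field(spec, doc, features):
        """A bare column name ("label") is a raw lookup; anything else
        is rendered as a template ("{{question}}?") against the doc."""
        if spec in features:
            return doc[spec]
        return jinja2.Template(spec).render(**doc)

    doc = {"question": "Is water wet", "label": 1}
    print(render_field("label", doc, doc.keys()))          # 1 (raw value)
    print(render_field("{{question}}?", doc, doc.keys()))  # Is water wet?
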
@@ -793,7 +806,7 @@ class ConfigurableTask(Task):
                 cont = self.doc_to_target(doc)
                 arguments = [(ctx, " {}".format(cont)) for ctx in choices]
             else:
-                cont = self.create_choices(doc)
+                choices = self.doc_to_choice(doc)
                 arguments = [(ctx, " {}".format(cont)) for cont in choices]

             request_list = [
@@ -896,12 +909,14 @@ class ConfigurableTask(Task):
            # retrieve choices in List[str] form, to compute choice lengths, etc.
            if self.multiple_input:
-               choices = [self.doc_to_text(doc)] * self.multiple_input
+               choices = [self.doc_to_target(doc)] * self.multiple_input
            else:
-               choices = self.create_choices(doc)
+               choices = self.doc_to_choice(doc)
            completion_len = np.array([float(len(i)) for i in choices])

+           if self.multiple_output:
+               pass

            if (
                2 * len(choices) == len(lls)
                and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -916,15 +931,21 @@ class ConfigurableTask(Task):
            pred_idx = np.argmax(lls)
            pred_idx_norm = np.argmax(lls / completion_len)

+           # Gives priority to evaluate base on gold_alias
            if self._config.gold_alias is not None:
-               gold = int(self.gold_alias(doc))
                pred = pred_idx
                pred_norm = pred_idx_norm
+               gold_idx = int(self.gold_alias(doc))
+               gold = gold_idx
            else:
-               gold = self.doc_to_target(doc)
-               gold_idx = choices.index(gold)
                pred = choices[pred_idx]
                pred_norm = choices[pred_idx_norm]
+               gold = self.doc_to_target(doc)
+               if type(gold) == int:
+                   gold_idx = gold
+                   gold = choices[gold_idx]
+               elif type(gold) == str:
+                   gold_idx = choices.index(gold)

            acc = 1.0 if pred == gold else 0.0
            acc_norm = 1.0 if pred_norm == gold else 0.0
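
Since doc_to_target may now return either a choice index or a choice string, the scoring path normalizes both into a (gold_idx, gold) pair before comparing against the prediction. A self-contained sketch with fabricated numbers (resolve_gold is illustrative only):

    import numpy as np

    choices = ["no", "yes"]
    lls = np.array([-4.2, -1.3])  # fabricated per-choice loglikelihoods
    completion_len = np.array([float(len(c)) for c in choices])

    pred = choices[int(np.argmax(lls))]
    pred_norm = choices[int(np.argmax(lls / completion_len))]

    def resolve_gold(gold, choices):
        """Accept either an index (int) or a choice string (str)."""
        if isinstance(gold, int):
            return gold, choices[gold]
        return choices.index(gold), gold

    gold_idx, gold = resolve_gold(1, choices)  # same as resolve_gold("yes", choices)
    acc = 1.0 if pred == gold else 0.0         # 1.0 here
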
@@ -7,12 +7,7 @@ output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
-doc_to_target: "{{answer_choices[label]}}"
-gold_alias: "{{label}}" # this will be cast to an int.
-template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
+doc_to_target: label
+doc_to_choice: {0: "no", 1: "yes"}
 metric_list:
-  - metric: exact_match
-    aggregation: mean
-    higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
+  - metric: acc
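
Under the new schema this BoolQ-style config drops template_aliases and gold_alias: doc_to_target: label pulls the raw integer column, and doc_to_choice maps it to a string. A fabricated end-to-end walk-through of one row:

    # Fabricated BoolQ-style row; field names follow the config above.
    doc = {
        "passage": "Water boils at 100 C at sea level.",
        "question": "does water boil at 100 degrees",
        "label": 1,
    }

    doc_to_choice = {0: "no", 1: "yes"}
    choices = list(doc_to_choice.values())

    gold = doc["label"]       # doc_to_target: label -> raw column value (int)
    gold_idx = gold           # an int target is treated as a choice index
    gold = choices[gold_idx]  # -> "yes"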