Commit 656c310e authored by lintangsutawika
Browse files

Update processing to handle the doc_to_target variety

parent 3b50b941
...@@ -322,6 +322,8 @@ class Task(abc.ABC): ...@@ -322,6 +322,8 @@ class Task(abc.ABC):
self._config.template_aliases + "{{answer_choices}}", doc self._config.template_aliases + "{{answer_choices}}", doc
) )
) )
elif type(self._config.doc_to_choice) == dict:
return list(self._config.doc_to_choice.values())
elif type(self._config.doc_to_choice) == str: elif type(self._config.doc_to_choice) == str:
return utils.apply_template(self._config.doc_to_choice, doc) return utils.apply_template(self._config.doc_to_choice, doc)
else: else:
...@@ -645,16 +647,21 @@ class ConfigurableTask(Task): ...@@ -645,16 +647,21 @@ class ConfigurableTask(Task):
# Test One Doc # Test One Doc
test_doc = docs[0] test_doc = docs[0]
self.features = list(test_doc.keys())
test_text = self.doc_to_text(test_doc) test_text = self.doc_to_text(test_doc)
test_target = self.doc_to_target(test_doc)
# test_choice = self.doc_to_choice(test_doc)
if self._config.output_type == "multiple_choice": if self._config.output_type == "multiple_choice":
if type(test_text) is list: if type(test_text) is list:
self.multiple_input = len(test_text) self.multiple_input = len(test_text)
elif type(test_text) is str: elif type(test_text) is str:
self.multiple_input = 0 self.multiple_input = 0
# test_choice = self.doc_choice(test_doc)
# test_target = self.doc_to_target(test_doc) if type(test_target) is list:
self.multiple_output = len(test_target)
else:
self.multiple_output = 0
def download(self, dataset_kwargs=None): def download(self, dataset_kwargs=None):
...@@ -732,7 +739,10 @@ class ConfigurableTask(Task): ...@@ -732,7 +739,10 @@ class ConfigurableTask(Task):
doc_to_text = self._config.doc_to_text doc_to_text = self._config.doc_to_text
if type(doc_to_text) == str: if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc) if doc_to_text in self.features:
return doc[doc_to_text]
else:
return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text): elif callable(doc_to_text):
return doc_to_text(doc) return doc_to_text(doc)
# Used when applying a Promptsource template # Used when applying a Promptsource template
...@@ -750,7 +760,10 @@ class ConfigurableTask(Task): ...@@ -750,7 +760,10 @@ class ConfigurableTask(Task):
doc_to_target = self._config.doc_to_target doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str: if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc) if doc_to_target in self.features:
return doc[doc_to_target]
else:
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target): elif callable(doc_to_target):
return doc_to_target(doc) return doc_to_target(doc)
# Used when applying a Promptsource template # Used when applying a Promptsource template
...@@ -793,7 +806,7 @@ class ConfigurableTask(Task): ...@@ -793,7 +806,7 @@ class ConfigurableTask(Task):
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
arguments = [(ctx, " {}".format(cont)) for ctx in choices] arguments = [(ctx, " {}".format(cont)) for ctx in choices]
else: else:
cont = self.create_choices(doc) choices = self.doc_to_choice(doc)
arguments = [(ctx, " {}".format(cont)) for cont in choices] arguments = [(ctx, " {}".format(cont)) for cont in choices]
request_list = [ request_list = [
...@@ -896,12 +909,14 @@ class ConfigurableTask(Task): ...@@ -896,12 +909,14 @@ class ConfigurableTask(Task):
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
if self.multiple_input: if self.multiple_input:
choices = [self.doc_to_text(doc)] * self.multiple_input choices = [self.doc_to_target(doc)] * self.multiple_input
else: else:
choices = self.create_choices(doc) choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
if self.multiple_output:
pass
if ( if (
2 * len(choices) == len(lls) 2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_fn_list.keys() and "acc_mutual_info" in self._metric_fn_list.keys()
...@@ -916,15 +931,21 @@ class ConfigurableTask(Task): ...@@ -916,15 +931,21 @@ class ConfigurableTask(Task):
pred_idx = np.argmax(lls) pred_idx = np.argmax(lls)
pred_idx_norm = np.argmax(lls / completion_len) pred_idx_norm = np.argmax(lls / completion_len)
# Gives priority to evaluate base on gold_alias
if self._config.gold_alias is not None: if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
pred = pred_idx pred = pred_idx
pred_norm = pred_idx_norm pred_norm = pred_idx_norm
gold_idx = int(self.gold_alias(doc))
gold = gold_idx
else: else:
gold = self.doc_to_target(doc)
gold_idx = choices.index(gold)
pred = choices[pred_idx] pred = choices[pred_idx]
pred_norm = choices[pred_idx_norm] pred_norm = choices[pred_idx_norm]
gold = self.doc_to_target(doc)
if type(gold) == int:
gold_idx = gold
gold = choices[gold_idx]
elif type(gold) == str:
gold_idx = choices.index(gold)
acc = 1.0 if pred == gold else 0.0 acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0 acc_norm = 1.0 if pred_norm == gold else 0.0
......
...@@ -7,12 +7,7 @@ output_type: multiple_choice ...@@ -7,12 +7,7 @@ output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{answer_choices[label]}}" doc_to_target: label
gold_alias: "{{label}}" # this will be cast to an int. doc_to_choice: {0: "no", 1: "yes"}
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
metric_list: metric_list:
- metric: exact_match - metric: acc
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment