Commit f33c6c9b authored by Lintang Sutawika
Browse files

Testing for multiple inputs from doc_to_text

parent 552b09e0
......@@ -315,17 +315,17 @@ class Task(abc.ABC):
"""
return doc
def create_choices(self, doc):
if self._config.create_choices is None:
def doc_to_choice(self, doc):
if self._config.doc_to_choice is None:
return ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
elif type(self._config.create_choices) == str:
return utils.apply_template(self._config.create_choices, doc)
elif type(self._config.doc_to_choice) == str:
return utils.apply_template(self._config.doc_to_choice, doc)
else:
return self._config.create_choices(doc)
return self._config.doc_to_choice(doc)
@property
def instances(self):
......@@ -479,7 +479,10 @@ class Task(abc.ABC):
)
example = self.doc_to_text(doc)
return labeled_examples + example
if type(example) == str:
return labeled_examples + example
elif type(example) == list:
return [labeled_examples + ex for ex in example]
def apply_filters(self):
......@@ -632,8 +635,6 @@ class ConfigurableTask(Task):
self.dataset.rename_column(key, alias)
def __post_init__(self):
if self.has_test_docs():
docs = self.test_docs()
elif self.has_validation_docs():
......@@ -647,15 +648,15 @@ class ConfigurableTask(Task):
test_doc = docs[0]
test_text = self.doc_to_text(test_doc)
if OUTPUT_TYPE == "multiple_choice":
if self._config.output_type == "multiple_choice":
if type(test_text) is list:
self.multiple_input = True
elif type(test_text) is str:
self.multiple_input = False
test_choice = self.doc_choice(test_doc)
test_target = self.doc_to_target(test_doc)
test_choice = self.doc_choice(test_doc)
# test_target = self.doc_to_target(test_doc)
def download(self, dataset_kwargs=None):
......@@ -805,7 +806,7 @@ class ConfigurableTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=arguments,
arguments=arg,
idx=i,
**kwargs,
)
......@@ -840,7 +841,7 @@ class ConfigurableTask(Task):
# similar to multiple_choice task type except each request contains
# multiple differing contexts with the same continuation
contexts = self.create_choices(doc)
contexts = self.doc_to_choice(doc)
choice = self.doc_to_target(doc)
request_list = [
......
def doc_to_text(doc):
    """Return one candidate context per answer option.

    The sentence is cut at its first "_" placeholder and each option is
    appended to the text before it, so the result is a list of strings
    (one per option) rather than a single prompt.

    Args:
        doc: Mapping with a "sentence" string containing "_" and the
            candidate fillers under "option1" and "option2".
            NOTE(review): "option2" is assumed from the Winogrande
            dataset schema -- confirm against the loaded dataset.

    Returns:
        A list ``[prefix + option1, prefix + option2]``.

    Raises:
        ValueError: If the sentence contains no "_" placeholder.
    """
    sentence = doc["sentence"]
    idx = sentence.index("_")
    prefix = sentence[:idx]
    # Iterate over the options themselves; looping over doc["option1"]
    # directly would walk the characters of a single option string.
    options = (doc["option1"], doc["option2"])
    return [prefix + opt for opt in options]
def doc_to_target(doc):
    """Return the part of the sentence that follows the first "_".

    Surrounding whitespace is stripped from the returned text. A
    ValueError is raised when the sentence has no "_" placeholder.
    """
    sentence = doc["sentence"]
    # Start just past the placeholder character itself.
    after_blank = sentence.index("_") + 1
    continuation = sentence[after_blank:]
    return continuation.strip()
def partial_context(doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
......
task: winogrande
dataset_path: winogrande
dataset_name: winogrande_xl
output_type: winograd_schema
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_target: !function preprocess_winogrande.partial_target
doc_to_text: !function preprocess_winogrande.create_choices
gold_alias: !function preprocess_winogrande.gold_alias
doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_text: !function preprocess_winogrande.doc_to_text
# gold_alias: !function preprocess_winogrande.gold_alias
metric_list:
- metric: acc
aggregation: mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment