Commit f33c6c9b authored by Lintang Sutawika's avatar Lintang Sutawika
Browse files

Testing for multiple inputs from doc_to_text

parent 552b09e0
...@@ -315,17 +315,17 @@ class Task(abc.ABC): ...@@ -315,17 +315,17 @@ class Task(abc.ABC):
""" """
return doc return doc
def create_choices(self, doc): def doc_to_choice(self, doc):
if self._config.create_choices is None: if self._config.doc_to_choice is None:
return ast.literal_eval( return ast.literal_eval(
utils.apply_template( utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc self._config.template_aliases + "{{answer_choices}}", doc
) )
) )
elif type(self._config.create_choices) == str: elif type(self._config.doc_to_choice) == str:
return utils.apply_template(self._config.create_choices, doc) return utils.apply_template(self._config.doc_to_choice, doc)
else: else:
return self._config.create_choices(doc) return self._config.doc_to_choice(doc)
@property @property
def instances(self): def instances(self):
...@@ -479,7 +479,10 @@ class Task(abc.ABC): ...@@ -479,7 +479,10 @@ class Task(abc.ABC):
) )
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
if type(example) == str:
return labeled_examples + example return labeled_examples + example
elif type(example) == list:
return [labeled_examples + ex for ex in example]
def apply_filters(self): def apply_filters(self):
...@@ -632,8 +635,6 @@ class ConfigurableTask(Task): ...@@ -632,8 +635,6 @@ class ConfigurableTask(Task):
self.dataset.rename_column(key, alias) self.dataset.rename_column(key, alias)
def __post_init__(self):
if self.has_test_docs(): if self.has_test_docs():
docs = self.test_docs() docs = self.test_docs()
elif self.has_validation_docs(): elif self.has_validation_docs():
...@@ -647,14 +648,14 @@ class ConfigurableTask(Task): ...@@ -647,14 +648,14 @@ class ConfigurableTask(Task):
test_doc = docs[0] test_doc = docs[0]
test_text = self.doc_to_text(test_doc) test_text = self.doc_to_text(test_doc)
if OUTPUT_TYPE == "multiple_choice": if self._config.output_type == "multiple_choice":
if type(test_text) is list: if type(test_text) is list:
self.multiple_input = True self.multiple_input = True
elif type(test_text) is str: elif type(test_text) is str:
self.multiple_input = False self.multiple_input = False
test_choice = self.doc_choice(test_doc) test_choice = self.doc_choice(test_doc)
test_target = self.doc_to_target(test_doc) # test_target = self.doc_to_target(test_doc)
def download(self, dataset_kwargs=None): def download(self, dataset_kwargs=None):
...@@ -805,7 +806,7 @@ class ConfigurableTask(Task): ...@@ -805,7 +806,7 @@ class ConfigurableTask(Task):
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=arguments, arguments=arg,
idx=i, idx=i,
**kwargs, **kwargs,
) )
...@@ -840,7 +841,7 @@ class ConfigurableTask(Task): ...@@ -840,7 +841,7 @@ class ConfigurableTask(Task):
# similar to multiple_choice task type except each request contains # similar to multiple_choice task type except each request contains
# multiple differing contexts with the same continuation # multiple differing contexts with the same continuation
contexts = self.create_choices(doc) contexts = self.doc_to_choice(doc)
choice = self.doc_to_target(doc) choice = self.doc_to_target(doc)
request_list = [ request_list = [
......
def doc_to_text(doc):
idx = doc["sentence"].index("_")
return [doc["sentence"][:idx] + opt for opt in doc["option1"]]
def doc_to_target(doc):
idx = doc["sentence"].index("_") + 1
return doc["sentence"][idx:].strip()
def partial_context(doc, option): def partial_context(doc, option):
# Substitute the pronoun in the sentence with the specified option # Substitute the pronoun in the sentence with the specified option
# and ignore everything after. # and ignore everything after.
......
task: winogrande task: winogrande
dataset_path: winogrande dataset_path: winogrande
dataset_name: winogrande_xl dataset_name: winogrande_xl
output_type: winograd_schema output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_target: !function preprocess_winogrande.partial_target doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_text: !function preprocess_winogrande.create_choices doc_to_text: !function preprocess_winogrande.doc_to_text
gold_alias: !function preprocess_winogrande.gold_alias # gold_alias: !function preprocess_winogrande.gold_alias
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment