Commit 3b50b941 authored by lintangsutawika's avatar lintangsutawika
Browse files

updated winogrande

parent f33c6c9b
...@@ -43,7 +43,7 @@ ALL_OUTPUT_TYPES = [ ...@@ -43,7 +43,7 @@ ALL_OUTPUT_TYPES = [
"multiple_choice", "multiple_choice",
"loglikelihood_rolling", "loglikelihood_rolling",
"greedy_until", "greedy_until",
"winograd_schema" "winograd_schema",
] ]
...@@ -634,7 +634,6 @@ class ConfigurableTask(Task): ...@@ -634,7 +634,6 @@ class ConfigurableTask(Task):
for key, alias in self._config.template_aliases: for key, alias in self._config.template_aliases:
self.dataset.rename_column(key, alias) self.dataset.rename_column(key, alias)
if self.has_test_docs(): if self.has_test_docs():
docs = self.test_docs() docs = self.test_docs()
elif self.has_validation_docs(): elif self.has_validation_docs():
...@@ -650,14 +649,13 @@ class ConfigurableTask(Task): ...@@ -650,14 +649,13 @@ class ConfigurableTask(Task):
if self._config.output_type == "multiple_choice": if self._config.output_type == "multiple_choice":
if type(test_text) is list: if type(test_text) is list:
self.multiple_input = True self.multiple_input = len(test_text)
elif type(test_text) is str: elif type(test_text) is str:
self.multiple_input = False self.multiple_input = 0
test_choice = self.doc_choice(test_doc) # test_choice = self.doc_choice(test_doc)
# test_target = self.doc_to_target(test_doc) # test_target = self.doc_to_target(test_doc)
def download(self, dataset_kwargs=None): def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset( self.dataset = datasets.load_dataset(
...@@ -737,7 +735,7 @@ class ConfigurableTask(Task): ...@@ -737,7 +735,7 @@ class ConfigurableTask(Task):
return utils.apply_template(doc_to_text, doc) return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text): elif callable(doc_to_text):
return doc_to_text(doc) return doc_to_text(doc)
# Used when applyting a Promptsource template # Used when applying a Promptsource template
elif hasattr(doc_to_text, "apply"): elif hasattr(doc_to_text, "apply"):
return doc_to_text.apply(doc)[0] return doc_to_text.apply(doc)[0]
else: else:
...@@ -755,7 +753,7 @@ class ConfigurableTask(Task): ...@@ -755,7 +753,7 @@ class ConfigurableTask(Task):
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target): elif callable(doc_to_target):
return doc_to_target(doc) return doc_to_target(doc)
# Used when applyting a Promptsource template # Used when applying a Promptsource template
elif hasattr(doc_to_target, "apply"): elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1] return doc_to_target.apply(doc)[1]
else: else:
...@@ -790,17 +788,13 @@ class ConfigurableTask(Task): ...@@ -790,17 +788,13 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
# we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list. # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
# TODO: any cleaner way to do this? # TODO: any cleaner way to do this?
if self.multiple_input: if self.multiple_input > 0:
choices = self.doc_to_text(doc) choices = self.doc_to_text(doc)
cont = self.doc_to_target(doc) cont = self.doc_to_target(doc)
arguments = [ arguments = [(ctx, " {}".format(cont)) for ctx in choices]
(ctx, " {}".format(cont)) for ctx in choices
]
else: else:
cont = self.create_choices(doc) cont = self.create_choices(doc)
arguments = [ arguments = [(ctx, " {}".format(cont)) for cont in choices]
(ctx, " {}".format(cont)) for cont in choices
]
request_list = [ request_list = [
Instance( Instance(
...@@ -901,7 +895,11 @@ class ConfigurableTask(Task): ...@@ -901,7 +895,11 @@ class ConfigurableTask(Task):
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
if self.multiple_input:
choices = [self.doc_to_text(doc)] * self.multiple_input
else:
choices = self.create_choices(doc) choices = self.create_choices(doc)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
if ( if (
......
def doc_to_text(doc): def doc_to_text(doc):
idx = doc["sentence"].index("_") idx = doc["sentence"].index("_")
return [doc["sentence"][:idx] + opt for opt in doc["option1"]] options = [doc["option1"], doc["option2"]]
return [doc["sentence"][:idx] + opt for opt in options]
def doc_to_target(doc): def doc_to_target(doc):
idx = doc["sentence"].index("_") + 1 idx = doc["sentence"].index("_") + 1
return doc["sentence"][idx:].strip() return doc["sentence"][idx:].strip()
def partial_context(doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
return doc["sentence"][:pronoun_loc] + option
def partial_target(doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def create_choices(doc):
choices = []
for option in [doc["option1"], doc["option2"]]:
partial_ctx = partial_context(doc, option)
choices.append(partial_ctx)
return choices
def gold_alias(doc): def gold_alias(doc):
answer_to_num = {"1": 0, "2": 1} answer_to_num = {"1": 0, "2": 1}
return answer_to_num[doc['answer']] return answer_to_num[doc["answer"]]
\ No newline at end of file
...@@ -6,7 +6,7 @@ training_split: train ...@@ -6,7 +6,7 @@ training_split: train
validation_split: validation validation_split: validation
doc_to_target: !function preprocess_winogrande.doc_to_target doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_text: !function preprocess_winogrande.doc_to_text doc_to_text: !function preprocess_winogrande.doc_to_text
# gold_alias: !function preprocess_winogrande.gold_alias gold_alias: !function preprocess_winogrande.gold_alias
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment