Commit 3b50b941 authored by lintangsutawika
Browse files

updated winogrande

parent f33c6c9b
......@@ -43,7 +43,7 @@ ALL_OUTPUT_TYPES = [
"multiple_choice",
"loglikelihood_rolling",
"greedy_until",
"winograd_schema"
"winograd_schema",
]
......@@ -314,14 +314,14 @@ class Task(abc.ABC):
The processed version of the specified `doc`.
"""
return doc
def doc_to_choice(self, doc):
if self._config.doc_to_choice is None:
return ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
elif type(self._config.doc_to_choice) == str:
return utils.apply_template(self._config.doc_to_choice, doc)
else:
......@@ -634,7 +634,6 @@ class ConfigurableTask(Task):
for key, alias in self._config.template_aliases:
self.dataset.rename_column(key, alias)
if self.has_test_docs():
docs = self.test_docs()
elif self.has_validation_docs():
......@@ -650,14 +649,13 @@ class ConfigurableTask(Task):
if self._config.output_type == "multiple_choice":
if type(test_text) is list:
self.multiple_input = True
self.multiple_input = len(test_text)
elif type(test_text) is str:
self.multiple_input = False
test_choice = self.doc_choice(test_doc)
self.multiple_input = 0
# test_choice = self.doc_choice(test_doc)
# test_target = self.doc_to_target(test_doc)
def download(self, dataset_kwargs=None):
self.dataset = datasets.load_dataset(
......@@ -737,7 +735,7 @@ class ConfigurableTask(Task):
return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text):
return doc_to_text(doc)
# Used when applyting a Promptsource template
# Used when applying a Promptsource template
elif hasattr(doc_to_text, "apply"):
return doc_to_text.apply(doc)[0]
else:
......@@ -755,7 +753,7 @@ class ConfigurableTask(Task):
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
# Used when applyting a Promptsource template
# Used when applying a Promptsource template
elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
......@@ -789,18 +787,14 @@ class ConfigurableTask(Task):
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
# we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
# TODO: any cleaner way to do this?
if self.multiple_input:
# TODO: any cleaner way to do this?
if self.multiple_input > 0:
choices = self.doc_to_text(doc)
cont = self.doc_to_target(doc)
arguments = [
(ctx, " {}".format(cont)) for ctx in choices
]
arguments = [(ctx, " {}".format(cont)) for ctx in choices]
else:
cont = self.create_choices(doc)
arguments = [
(ctx, " {}".format(cont)) for cont in choices
]
arguments = [(ctx, " {}".format(cont)) for cont in choices]
request_list = [
Instance(
......@@ -843,7 +837,7 @@ class ConfigurableTask(Task):
contexts = self.doc_to_choice(doc)
choice = self.doc_to_target(doc)
request_list = [
Instance(
request_type="loglikelihood",
......@@ -854,7 +848,7 @@ class ConfigurableTask(Task):
)
for i, context in enumerate(contexts)
]
return request_list
return Instance(
......@@ -901,7 +895,11 @@ class ConfigurableTask(Task):
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = self.create_choices(doc)
if self.multiple_input:
choices = [self.doc_to_text(doc)] * self.multiple_input
else:
choices = self.create_choices(doc)
completion_len = np.array([float(len(i)) for i in choices])
if (
......@@ -917,7 +915,7 @@ class ConfigurableTask(Task):
pred_idx = np.argmax(lls)
pred_idx_norm = np.argmax(lls / completion_len)
if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
pred = pred_idx
......
def doc_to_text(doc):
    """Build one candidate context per answer option.

    The sentence is cut at its first "_" and each of ``doc["option1"]`` and
    ``doc["option2"]`` is appended to that prefix. Everything from the blank
    onwards is dropped.

    Args:
        doc: Mapping with "sentence", "option1" and "option2" entries
            (presumably one Winogrande row -- confirm against the dataset).

    Returns:
        A two-element list of strings, the option1 context first.

    Raises:
        ValueError: If the sentence contains no "_" (assuming it is a str).
    """
    idx = doc["sentence"].index("_")
    # Iterate over the two options themselves; looping over doc["option1"]
    # alone would walk the characters of a single option string.
    options = [doc["option1"], doc["option2"]]
    return [doc["sentence"][:idx] + opt for opt in options]
def doc_to_target(doc):
    """Return the text that follows the blank in the document's sentence.

    Args:
        doc: Mapping with a "sentence" entry containing a "_" placeholder.

    Returns:
        Everything after the first "_", with surrounding whitespace stripped.

    Raises:
        ValueError: If the sentence contains no "_" (assuming it is a str).
    """
    # +1 skips the "_" character itself.
    idx = doc["sentence"].index("_") + 1
    return doc["sentence"][idx:].strip()
def partial_context(doc, option):
    """Substitute the blank with ``option`` and drop everything after it.

    Args:
        doc: Mapping with a "sentence" entry containing a "_" placeholder.
        option: Text to put in place of the blank.

    Returns:
        The sentence up to (not including) the first "_", followed by
        ``option``.

    Raises:
        ValueError: If the sentence contains no "_" (assuming it is a str).
    """
    pronoun_loc = doc["sentence"].index("_")
    return doc["sentence"][:pronoun_loc] + option
def partial_target(doc):
    """Return the part of the sentence that comes after the blank.

    Args:
        doc: Mapping with a "sentence" entry containing a "_" placeholder.

    Returns:
        The text following the first "_", stripped of leading and trailing
        whitespace.

    Raises:
        ValueError: If the sentence contains no "_" (assuming it is a str).
    """
    # Start one past the placeholder so the "_" is not part of the target.
    pronoun_loc = doc["sentence"].index("_") + 1
    return doc["sentence"][pronoun_loc:].strip()
def create_choices(doc):
    """Return the partial context for each of the document's two options.

    Args:
        doc: Mapping with "sentence", "option1" and "option2" entries.

    Returns:
        A list with the result of ``partial_context`` for option1, then for
        option2.
    """
    return [
        partial_context(doc, option)
        for option in (doc["option1"], doc["option2"])
    ]
def gold_alias(doc):
    """Map the document's answer label to a zero-based choice index.

    Args:
        doc: Mapping with an "answer" entry holding the label "1" or "2".

    Returns:
        0 for label "1", 1 for label "2".

    Raises:
        KeyError: If the label is anything other than "1" or "2".
    """
    answer_to_num = {"1": 0, "2": 1}
    return answer_to_num[doc["answer"]]
......@@ -6,7 +6,7 @@ training_split: train
validation_split: validation
doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_text: !function preprocess_winogrande.doc_to_text
# gold_alias: !function preprocess_winogrande.gold_alias
gold_alias: !function preprocess_winogrande.gold_alias
metric_list:
- metric: acc
aggregation: mean
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment