Commit 74f3307b authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Fix few-shot contexts for `Winograd Schema Challenge`-based tasks

parent 9b14235d
......@@ -13,6 +13,8 @@ class Winogrande(HFTask):
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
answer_to_num = {'1': 0, '2': 1}
def has_training_docs(self):
return True
......@@ -20,20 +22,20 @@ class Winogrande(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the sentence with each candidate choice
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
return context1, context2
return doc["sentence"][:pronoun_loc] + option
def doc_to_target(self, doc):
return " " + self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
......@@ -41,15 +43,21 @@ class Winogrande(HFTask):
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
def doc_to_target(self, doc):
return self.partial_target(doc)
def fewshot_context(self, doc, num_fewshot, provide_description):
fewshot_ctx = super().fewshot_context(doc, num_fewshot, provide_description)
return fewshot_ctx + "\n" + self._wrong_partial_context(doc)
def _wrong_partial_context(self, doc):
wrong_answer = f"{int(not self.answer_to_num[doc['answer']]) + 1}"
wrong_option = doc["option" + wrong_answer]
return self.partial_context(doc, wrong_option)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -60,10 +68,19 @@ class Winogrande(HFTask):
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
right_ctx, wrong_ctx = self.split_fewshot_context(ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
@classmethod
def split_fewshot_context(cls, ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
partial_ctxs = ctx.pop().split("\n")
# NOTE: First context in `partial_ctxs` is always right because of `doc_to_text`.
right_ctx = "\n\n".join([*ctx, partial_ctxs[0]]) if ctx else partial_ctxs[0]
wrong_ctx = "\n\n".join([*ctx, partial_ctxs[1]]) if ctx else partial_ctxs[1]
return right_ctx, wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -75,9 +92,8 @@ class Winogrande(HFTask):
:param results:
The results of the requests created in construct_requests.
"""
answer = int(doc["answer"]) - 1 # `- 1` b/c doc["answer"] ∈ {'1', '2'}
return {
"acc": np.argmax(results) == answer
"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]
}
def aggregation(self):
......
......@@ -26,12 +26,12 @@ class WinogradSchemaChallenge273(HFTask):
data = []
for doc in self.data["test"]:
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc["options"][0], doc)
doc["options"][1] = self.__normalize_option(doc["options"][1], doc)
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
data.append(doc)
return {"test": data}
def __normalize_option(self, option, doc):
def __normalize_option(self, doc, option):
# Append `'s` to possessive determiner based options.
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
option += "'s"
......@@ -51,22 +51,17 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the original text with each candidate
# choice and ignore everything after.
context1 = doc["text"][:doc["pronoun_loc"]] + doc["options"][0]
context2 = doc["text"][:doc["pronoun_loc"]] + doc["options"][1]
return context1, context2
def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified
# option and ignore everything after.
return doc["text"][:doc["pronoun_loc"]] + option
def doc_to_target(self, doc):
return " " + self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
......@@ -74,15 +69,26 @@ class WinogradSchemaChallenge273(HFTask):
start_index = doc["pronoun_loc"] + len(doc["pronoun"])
return doc["text"][start_index:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def doc_to_target(self, doc):
return self.partial_target(doc)
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_context(self, doc, num_fewshot, provide_description):
fewshot_ctx = super().fewshot_context(doc, num_fewshot, provide_description)
return fewshot_ctx + "\n" + self._wrong_partial_context(doc)
def _wrong_partial_context(self, doc):
wrong_answer = int(not doc["label"])
wrong_option = doc["options"][wrong_answer]
return self.partial_context(doc, wrong_option)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -93,10 +99,19 @@ class WinogradSchemaChallenge273(HFTask):
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
right_ctx, wrong_ctx = self.split_fewshot_context(ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
@classmethod
def split_fewshot_context(cls, ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
partial_ctxs = ctx.pop().split("\n")
# NOTE: First context in `partial_ctxs` is always right because of `doc_to_text`.
right_ctx = "\n\n".join([*ctx, partial_ctxs[0]]) if ctx else partial_ctxs[0]
wrong_ctx = "\n\n".join([*ctx, partial_ctxs[1]]) if ctx else partial_ctxs[1]
return right_ctx, wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment