Commit 74f3307b authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Fix few-shot contexts for `Winograd Schema Challenge`-based tasks

parent 9b14235d
......@@ -13,6 +13,8 @@ class Winogrande(HFTask):
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
answer_to_num = {'1': 0, '2': 1}
def has_training_docs(self):
return True
......@@ -20,20 +22,20 @@ class Winogrande(HFTask):
return True
def has_test_docs(self):
return True
return False
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the sentence with each candidate choice
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
return context1, context2
return doc["sentence"][:pronoun_loc] + option
def doc_to_target(self, doc):
return " " + self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
......@@ -41,33 +43,48 @@ class Winogrande(HFTask):
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
def doc_to_target(self, doc):
return self.partial_target(doc)
def fewshot_context(self, doc, num_fewshot, provide_description):
fewshot_ctx = super().fewshot_context(doc, num_fewshot, provide_description)
return fewshot_ctx + "\n" + self._wrong_partial_context(doc)
def _wrong_partial_context(self, doc):
wrong_answer = f"{int(not self.answer_to_num[doc['answer']]) + 1}"
wrong_option = doc["option" + wrong_answer]
return self.partial_context(doc, wrong_option)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
right_ctx, wrong_ctx = self.split_fewshot_context(ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
@classmethod
def split_fewshot_context(cls, ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
partial_ctxs = ctx.pop().split("\n")
# NOTE: First context in `partial_ctxs` is always right because of `doc_to_text`.
right_ctx = "\n\n".join([*ctx, partial_ctxs[0]]) if ctx else partial_ctxs[0]
wrong_ctx = "\n\n".join([*ctx, partial_ctxs[1]]) if ctx else partial_ctxs[1]
return right_ctx, wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -75,15 +92,14 @@ class Winogrande(HFTask):
:param results:
The results of the requests created in construct_requests.
"""
answer = int(doc["answer"]) - 1 # `- 1` b/c doc["answer"] ∈ {'1', '2'}
return {
"acc": np.argmax(results) == answer
"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
......@@ -93,7 +109,7 @@ class Winogrande(HFTask):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
......
......@@ -26,14 +26,14 @@ class WinogradSchemaChallenge273(HFTask):
data = []
for doc in self.data["test"]:
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc["options"][0], doc)
doc["options"][1] = self.__normalize_option(doc["options"][1], doc)
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
data.append(doc)
return {"test": data}
def __normalize_option(self, option, doc):
def __normalize_option(self, doc, option):
# Append `'s` to possessive determiner based options.
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
option += "'s"
# Appropriately lowercase the pronoun in the option.
pronoun = option.split()[0]
......@@ -51,22 +51,17 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the original text with each candidate
# choice and ignore everything after.
context1 = doc["text"][:doc["pronoun_loc"]] + doc["options"][0]
context2 = doc["text"][:doc["pronoun_loc"]] + doc["options"][1]
return context1, context2
def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified
# option and ignore everything after.
return doc["text"][:doc["pronoun_loc"]] + option
def doc_to_target(self, doc):
return " " + self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
......@@ -74,33 +69,53 @@ class WinogradSchemaChallenge273(HFTask):
start_index = doc["pronoun_loc"] + len(doc["pronoun"])
return doc["text"][start_index:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def doc_to_target(self, doc):
return self.partial_target(doc)
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_context(self, doc, num_fewshot, provide_description):
fewshot_ctx = super().fewshot_context(doc, num_fewshot, provide_description)
return fewshot_ctx + "\n" + self._wrong_partial_context(doc)
def _wrong_partial_context(self, doc):
wrong_answer = int(not doc["label"])
wrong_option = doc["options"][wrong_answer]
return self.partial_context(doc, wrong_option)
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
right_ctx, wrong_ctx = self.split_fewshot_context(ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
@classmethod
def split_fewshot_context(cls, ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
partial_ctxs = ctx.pop().split("\n")
# NOTE: First context in `partial_ctxs` is always right because of `doc_to_text`.
right_ctx = "\n\n".join([*ctx, partial_ctxs[0]]) if ctx else partial_ctxs[0]
wrong_ctx = "\n\n".join([*ctx, partial_ctxs[1]]) if ctx else partial_ctxs[1]
return right_ctx, wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -115,7 +130,7 @@ class WinogradSchemaChallenge273(HFTask):
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
......@@ -125,7 +140,7 @@ class WinogradSchemaChallenge273(HFTask):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment