Unverified Commit c6b094e0 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #138 from jon-tow/winograd-fixes

Fix few-shot contexts for `Winograd Schema Challenge`-based tasks
parents 7614a8f3 1a159d6b
......@@ -13,6 +13,8 @@ class Winogrande(HFTask):
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
answer_to_num = {'1': 0, '2': 1}
def has_training_docs(self):
return True
......@@ -20,54 +22,59 @@ class Winogrande(HFTask):
return True
def has_test_docs(self):
return True
return False
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the sentence with each candidate choice
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
return context1, context2
return doc["sentence"][:pronoun_loc] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def doc_to_target(self, doc):
return self.partial_target(doc)
return " " + doc["sentence"][pronoun_loc:].strip()
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
right_ctx, wrong_ctx = ctx, self._wrong_context(doc, ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
def _wrong_context(self, doc, ctx):
wrong_answer = f"{int(not self.answer_to_num[doc['answer']]) + 1}"
wrong_option = doc["option" + wrong_answer]
wrong_ctx = self.partial_context(doc, wrong_option)
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context.
return "\n\n".join([*ctx, wrong_ctx]) if ctx else wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -75,15 +82,14 @@ class Winogrande(HFTask):
:param results:
The results of the requests created in construct_requests.
"""
answer = int(doc["answer"]) - 1 # `- 1` b/c doc["answer"] ∈ {'1', '2'}
return {
"acc": np.argmax(results) == answer
"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
......@@ -93,7 +99,7 @@ class Winogrande(HFTask):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
......
......@@ -26,14 +26,14 @@ class WinogradSchemaChallenge273(HFTask):
data = []
for doc in self.data["test"]:
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc["options"][0], doc)
doc["options"][1] = self.__normalize_option(doc["options"][1], doc)
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
data.append(doc)
return {"test": data}
def __normalize_option(self, option, doc):
def __normalize_option(self, doc, option):
# Append `'s` to possessive determiner based options.
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
option += "'s"
# Appropriately lowercase the pronoun in the option.
pronoun = option.split()[0]
......@@ -51,56 +51,61 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the original text with each candidate
# choice and ignore everything after.
context1 = doc["text"][:doc["pronoun_loc"]] + doc["options"][0]
context2 = doc["text"][:doc["pronoun_loc"]] + doc["options"][1]
return context1, context2
def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified
# option and ignore everything after.
return doc["text"][:doc["pronoun_loc"]] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
start_index = doc["pronoun_loc"] + len(doc["pronoun"])
return doc["text"][start_index:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def doc_to_target(self, doc):
return self.partial_target(doc)
return " " + doc["text"][start_index:].strip()
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
right_ctx, wrong_ctx = ctx, self._wrong_context(doc, ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
def _wrong_context(self, doc, ctx):
wrong_answer = int(not doc["label"])
wrong_option = doc["options"][wrong_answer]
wrong_ctx = self.partial_context(doc, wrong_option)
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context.
return "\n\n".join([*ctx, wrong_ctx]) if ctx else wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -115,7 +120,7 @@ class WinogradSchemaChallenge273(HFTask):
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
......@@ -125,7 +130,7 @@ class WinogradSchemaChallenge273(HFTask):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment