"cacheflow/engine/async_llm_engine.py" did not exist on "da5ddcd544ac5ce6bc4f522af9cbdc315f94620e"
Commit 3a90e246 authored by Jonathan Tow
Browse files

Remove incorrect context from `fewshot_context`

parent 74f3307b
......@@ -27,6 +27,10 @@ class Winogrande(HFTask):
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
......@@ -43,19 +47,6 @@ class Winogrande(HFTask):
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
def fewshot_context(self, doc, num_fewshot, provide_description):
fewshot_ctx = super().fewshot_context(doc, num_fewshot, provide_description)
return fewshot_ctx + "\n" + self._wrong_partial_context(doc)
def _wrong_partial_context(self, doc):
wrong_answer = f"{int(not self.answer_to_num[doc['answer']]) + 1}"
wrong_option = doc["option" + wrong_answer]
return self.partial_context(doc, wrong_option)
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -68,19 +59,18 @@ class Winogrande(HFTask):
part of the document for `doc`.
"""
target = self.partial_target(doc)
right_ctx, wrong_ctx = self.split_fewshot_context(ctx)
right_ctx, wrong_ctx = ctx, self._wrong_context(doc, ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
@classmethod
def split_fewshot_context(cls, ctx):
def _wrong_context(self, doc, ctx):
wrong_answer = f"{int(not self.answer_to_num[doc['answer']]) + 1}"
wrong_option = doc["option" + wrong_answer]
wrong_ctx = self.partial_context(doc, wrong_option)
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
partial_ctxs = ctx.pop().split("\n")
# NOTE: First context in `partial_ctxs` is always right because of `doc_to_text`.
right_ctx = "\n\n".join([*ctx, partial_ctxs[0]]) if ctx else partial_ctxs[0]
wrong_ctx = "\n\n".join([*ctx, partial_ctxs[1]]) if ctx else partial_ctxs[1]
return right_ctx, wrong_ctx
ctx.pop() # Remove the correct context.
return "\n\n".join([*ctx, wrong_ctx]) if ctx else wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
......@@ -51,6 +51,15 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
......@@ -69,24 +78,6 @@ class WinogradSchemaChallenge273(HFTask):
start_index = doc["pronoun_loc"] + len(doc["pronoun"])
return doc["text"][start_index:].strip()
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_context(self, doc, num_fewshot, provide_description):
fewshot_ctx = super().fewshot_context(doc, num_fewshot, provide_description)
return fewshot_ctx + "\n" + self._wrong_partial_context(doc)
def _wrong_partial_context(self, doc):
wrong_answer = int(not doc["label"])
wrong_option = doc["options"][wrong_answer]
return self.partial_context(doc, wrong_option)
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -99,19 +90,18 @@ class WinogradSchemaChallenge273(HFTask):
part of the document for `doc`.
"""
target = self.partial_target(doc)
right_ctx, wrong_ctx = self.split_fewshot_context(ctx)
right_ctx, wrong_ctx = ctx, self._wrong_context(doc, ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
@classmethod
def split_fewshot_context(cls, ctx):
def _wrong_context(self, doc, ctx):
wrong_answer = int(not doc["label"])
wrong_option = doc["options"][wrong_answer]
wrong_ctx = self.partial_context(doc, wrong_option)
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
partial_ctxs = ctx.pop().split("\n")
# NOTE: First context in `partial_ctxs` is always right because of `doc_to_text`.
right_ctx = "\n\n".join([*ctx, partial_ctxs[0]]) if ctx else partial_ctxs[0]
wrong_ctx = "\n\n".join([*ctx, partial_ctxs[1]]) if ctx else partial_ctxs[1]
return right_ctx, wrong_ctx
ctx.pop() # Remove the correct context.
return "\n\n".join([*ctx, wrong_ctx]) if ctx else wrong_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment