Unverified Commit 9adf18b1 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #141 from jon-tow/winograd-fixes2

Fix context orderings in `Winograds` to match label positioning
parents c6b094e0 155aeee9
......@@ -59,18 +59,18 @@ class Winogrande(HFTask):
part of the document for `doc`.
"""
target = self.partial_target(doc)
right_ctx, wrong_ctx = ctx, self._wrong_context(doc, ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
def _wrong_context(self, doc, ctx):
wrong_answer = f"{int(not self.answer_to_num[doc['answer']]) + 1}"
wrong_option = doc["option" + wrong_answer]
wrong_ctx = self.partial_context(doc, wrong_option)
lls = []
for option in [doc["option1"], doc["option2"]]:
partial_ctx = self.partial_context(doc, option)
full_ctx = self.append_context(ctx, partial_ctx)
lls.append(rf.loglikelihood(full_ctx, target)[0])
return lls
@classmethod
def append_context(cls, ctx, partial_ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context.
return "\n\n".join([*ctx, wrong_ctx]) if ctx else wrong_ctx
ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
......@@ -90,18 +90,18 @@ class WinogradSchemaChallenge273(HFTask):
part of the document for `doc`.
"""
target = self.partial_target(doc)
right_ctx, wrong_ctx = ctx, self._wrong_context(doc, ctx)
ll_right_ctx, _ = rf.loglikelihood(right_ctx, target)
ll_wrong_ctx, _ = rf.loglikelihood(wrong_ctx, target)
return ll_right_ctx, ll_wrong_ctx
def _wrong_context(self, doc, ctx):
wrong_answer = int(not doc["label"])
wrong_option = doc["options"][wrong_answer]
wrong_ctx = self.partial_context(doc, wrong_option)
lls = []
for option in doc["options"]:
partial_ctx = self.partial_context(doc, option)
full_ctx = self.append_context(ctx, partial_ctx)
lls.append(rf.loglikelihood(full_ctx, target)[0])
return lls
@classmethod
def append_context(cls, ctx, partial_ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context.
return "\n\n".join([*ctx, wrong_ctx]) if ctx else wrong_ctx
ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment