Unverified Commit 7031c324 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #98 from zphang/wsc

SGWinogradSchemaChallenge
parents 94d782a0 21d527d4
...@@ -33,7 +33,7 @@ TASK_REGISTRY = { ...@@ -33,7 +33,7 @@ TASK_REGISTRY = {
"multirc": superglue.MultiRC, "multirc": superglue.MultiRC,
"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
#"wsc": superglue.SGWinogradSchemaChallenge, # not implemented yet "wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre? # Order by benchmark/genre?
......
...@@ -401,6 +401,8 @@ class WordsInContext(HFTask): ...@@ -401,6 +401,8 @@ class WordsInContext(HFTask):
class SGWinogradSchemaChallenge(HFTask): class SGWinogradSchemaChallenge(HFTask):
# Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
# binary version of the task.
DATASET_PATH = "super_glue" DATASET_PATH = "super_glue"
DATASET_NAME = "wsc" DATASET_NAME = "wsc"
...@@ -425,7 +427,6 @@ class SGWinogradSchemaChallenge(HFTask): ...@@ -425,7 +427,6 @@ class SGWinogradSchemaChallenge(HFTask):
return self._training_docs return self._training_docs
def fewshot_description(self): def fewshot_description(self):
# TODO: figure out actual description
return "Final Exam with Answer Key\n" \ return "Final Exam with Answer Key\n" \
"Instructions: Please carefully read the following passages. " \ "Instructions: Please carefully read the following passages. " \
"For each passage, you must identify which noun the pronoun marked in *bold*" \ "For each passage, you must identify which noun the pronoun marked in *bold*" \
...@@ -438,24 +439,34 @@ class SGWinogradSchemaChallenge(HFTask): ...@@ -438,24 +439,34 @@ class SGWinogradSchemaChallenge(HFTask):
+ "*{}*".format(doc["span2_text"]) + "*{}*".format(doc["span2_text"])
+ raw_passage[doc["span2_index"] + len(doc["span2_text"]):] + raw_passage[doc["span2_index"] + len(doc["span2_text"]):]
) )
noun = doc["span1_text"]
pronoun = doc["span2_text"] pronoun = doc["span2_text"]
text = ( text = (
f"Passage: {passage}\n" f"Passage: {passage}\n"
+ f"Question: In the passage above, what does the pronoun \"*{pronoun}*\" refer to?\n" + f"Question: In the passage above, does the pronoun \"*{pronoun}*\" refer to \"*{noun}*\"?\n"
+ "Answer:" + "Answer:"
) )
return text return text
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " {}".format(doc["span1_text"]) return " " + yesno(doc['label'])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
# Evaluate probability of generating answer based on span1_text (coref target)
raise NotImplementedError("requires free-form generation") ll_yes, _ = rf.loglikelihood(ctx, ' yes')
ll_no, _ = rf.loglikelihood(ctx, ' no')
return ll_yes, ll_no
def process_results(self, doc, results): def process_results(self, doc, results):
# Evaluate probability of generating answer based on span1_text (coref target) ll_yes, ll_no = results
raise NotImplementedError("requires evaluation from free-form generation") gold = doc["label"]
acc = 1. if (ll_yes > ll_no) == gold else 0.
return {
"acc": acc
}
def higher_is_better(self): def higher_is_better(self):
return { return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment