Commit e4e9228e authored by Jason Phang
Browse files

ReCoRD fix

parent 487b5313
...@@ -44,7 +44,7 @@ TASK_REGISTRY = { ...@@ -44,7 +44,7 @@ TASK_REGISTRY = {
"cb": superglue.CommitmentBank, "cb": superglue.CommitmentBank,
"copa": superglue.Copa, "copa": superglue.Copa,
"multirc": superglue.MultiRC, "multirc": superglue.MultiRC,
#"record": superglue.ReCoRD, "record": superglue.ReCoRD,
"wic": superglue.WordsInContext, "wic": superglue.WordsInContext,
"wsc": superglue.SGWinogradSchemaChallenge, "wsc": superglue.SGWinogradSchemaChallenge,
......
...@@ -272,7 +272,7 @@ class ReCoRD(HFTask): ...@@ -272,7 +272,7 @@ class ReCoRD(HFTask):
def training_docs(self): def training_docs(self):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing. # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no. # Each doc consists of multiple answer candidates, each of which is scored yes/no.
# Hence, we one "doc" for each (context + passage, answer) pair. # Hence, we create one "doc" for each (context + passage, answer) pair.
# Moreover, we only use the correct answers for context packing # Moreover, we only use the correct answers for context packing
# (This is not an issue for evaluation, where we can directly score multiple candidates at once). # (This is not an issue for evaluation, where we can directly score multiple candidates at once).
if self._training_docs is None: if self._training_docs is None:
...@@ -288,14 +288,10 @@ class ReCoRD(HFTask): ...@@ -288,14 +288,10 @@ class ReCoRD(HFTask):
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
for doc in self.data["validation"]: # Following from .training_docs, for validation_docs, each document corresponds to
for entity in list(set(doc["entities"])): # the original doc from the dataset, i.e. comprises lists of entities, and which
yield { # entities are correct (potentially multiple)
"passage": doc["passage"], yield from self.data["validation"]
"query": doc["query"],
"entity": entity,
"label": entity in doc["answers"],
}
def doc_to_text(self, doc): def doc_to_text(self, doc):
initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
...@@ -314,7 +310,7 @@ class ReCoRD(HFTask): ...@@ -314,7 +310,7 @@ class ReCoRD(HFTask):
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
requests = [ requests = [
rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity)) rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity))
for entity in doc["entity"] for entity in doc["entities"]
] ]
return requests return requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment