"vscode:/vscode.git/clone" did not exist on "db5ef3004c539688f15dbd3fb3ee9d8c0b48fe05"
Commit 2cfdd80a authored by Jason Phang
Browse files

ReCoRD fixup

parent e4e9228e
...@@ -272,26 +272,25 @@ class ReCoRD(HFTask): ...@@ -272,26 +272,25 @@ class ReCoRD(HFTask):
def training_docs(self): def training_docs(self):
# In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing. # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
# Each doc consists of multiple answer candidates, each of which is scored yes/no. # Each doc consists of multiple answer candidates, each of which is scored yes/no.
# Hence, we create one "doc" for each (context + passage, answer) pair.
# Moreover, we only use the correct answers for context packing
# (This is not an issue for evaluation, where we can directly score multiple candidates at once).
if self._training_docs is None: if self._training_docs is None:
self._training_docs = [] self._training_docs = []
for doc in self.data["train"]: for doc in self.data["train"]:
for entity in list(set(doc["entities"])): self._training_docs.append(self._process_doc(doc))
self._training_docs.append({
"passage": doc["passage"],
"query": doc["query"],
"entity": entity,
"label": entity in doc["answers"],
})
return self._training_docs return self._training_docs
def validation_docs(self): def validation_docs(self):
# Following from .training_docs, for validation_docs, each document corresponds to # See: training_docs
# the original doc from the dataset, i.e. comprises lists of entities, and which for doc in self.data["validation"]:
# entities are correct (potentially multiple) yield self._process_doc(doc)
yield from self.data["validation"]
@classmethod
def _process_doc(cls, doc):
return {
"passage": doc["passage"],
"query": doc["query"],
"entities": sorted(list(set(doc["entities"]))),
"answers": sorted(list(set(doc["answers"]))),
}
def doc_to_text(self, doc): def doc_to_text(self, doc):
initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n") initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
...@@ -305,7 +304,8 @@ class ReCoRD(HFTask): ...@@ -305,7 +304,8 @@ class ReCoRD(HFTask):
return f' - {query}'.replace("@placeholder", entity) return f' - {query}'.replace("@placeholder", entity)
def doc_to_target(self, doc): def doc_to_target(self, doc):
return self.format_answer(query=doc["query"], entity=doc["entity"]) # We only output the first correct entity in a doc
return self.format_answer(query=doc["query"], entity=doc["answers"][0])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
requests = [ requests = [
...@@ -319,10 +319,10 @@ class ReCoRD(HFTask): ...@@ -319,10 +319,10 @@ class ReCoRD(HFTask):
# - Pick the maximum likelihood prediction entity # - Pick the maximum likelihood prediction entity
# - Evaluate the accuracy and token F1 PER EXAMPLE # - Evaluate the accuracy and token F1 PER EXAMPLE
# - Average over all examples # - Average over all examples
max_idx = np.argmax(np.array(results)) max_idx = np.argmax(np.array([result[0] for result in results]))
prediction = doc["entities"][max_idx] prediction = doc["entities"][max_idx]
gold_label_set = list(set(doc["answers"])) gold_label_set = doc["answers"]
f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set) f1 = metric_max_over_ground_truths(squad_metrics.compute_f1, prediction, gold_label_set)
em = metric_max_over_ground_truths(squad_metrics.compute_exact, prediction, gold_label_set) em = metric_max_over_ground_truths(squad_metrics.compute_exact, prediction, gold_label_set)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment