Commit b44aa554 authored by cjlovering's avatar cjlovering
Browse files

QNLI with PS integration.

parent 4c201b97
...@@ -146,7 +146,7 @@ class MNLIMismatched(MNLI): ...@@ -146,7 +146,7 @@ class MNLIMismatched(MNLI):
return self.dataset["test_mismatched"] return self.dataset["test_mismatched"]
class QNLI(Task): class QNLI(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "glue" DATASET_PATH = "glue"
DATASET_NAME = "qnli" DATASET_NAME = "qnli"
...@@ -168,18 +168,6 @@ class QNLI(Task): ...@@ -168,18 +168,6 @@ class QNLI(Task):
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def process_results(self, doc, results):
ll_yes, ll_no = results
pred = ll_no > ll_yes
gold = doc["label"]
return {"acc": pred == gold}
def higher_is_better(self):
return {"acc": True}
def aggregation(self):
return {"acc": mean}
class WNLI(PromptSourceTask): class WNLI(PromptSourceTask):
VERSION = 1 VERSION = 1
...@@ -196,20 +184,11 @@ class WNLI(PromptSourceTask): ...@@ -196,20 +184,11 @@ class WNLI(PromptSourceTask):
return False return False
def training_docs(self): def training_docs(self):
# if self._training_docs is None:
# self._training_docs = list()
# return self._training_docs
return self.dataset["train"] return self.dataset["train"]
def validation_docs(self): def validation_docs(self):
return self.dataset["validation"] return self.dataset["validation"]
def higher_is_better(self):
return {"acc": True}
def aggregation(self):
return {"acc": mean}
class RTE(PromptSourceTask): class RTE(PromptSourceTask):
VERSION = 0 VERSION = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment