Commit e49cf8da authored by cjlovering's avatar cjlovering

SST with PS integration. (It was already done.)

parent 31a019c2
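The change is mechanical: SST drops its per-class higher_is_better and aggregation overrides, since the PromptSourceTask base class already supplies the accuracy plumbing (hence "it was already done"), and the remaining GLUE tasks get their metric dictionaries collapsed onto single lines. A minimal sketch of SST after this commit, assuming the base class provides the default metric handling; the DATASET_* values and the base-class internals are not visible in this diff:

# Hypothetical sketch, not the verbatim repository code. Assumes the
# surrounding module defines PromptSourceTask with default acc metrics.
class SST(PromptSourceTask):
    DATASET_PATH = "glue"  # assumed; not visible in this diff
    DATASET_NAME = "sst2"  # assumed; not visible in this diff

    def has_validation_docs(self):
        return True

    def validation_docs(self):
        return self.dataset["validation"]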
@@ -67,7 +67,7 @@ class CoLA(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]

     def process_results(self, doc, results):
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         pred = np.argmax(results)
         target = answer_choices_list.index(self.doc_to_target(doc).strip())
@@ -79,19 +79,13 @@ class CoLA(PromptSourceTask):
         print(f"PRED: {pred}")
         print("*" * 80)
-        return {
-            "mcc": (target, pred)
-        }
+        return {"mcc": (target, pred)}

     def higher_is_better(self):
-        return {
-            "mcc": True
-        }
+        return {"mcc": True}

     def aggregation(self):
-        return {
-            "mcc": matthews_corrcoef
-        }
+        return {"mcc": matthews_corrcoef}

 class SST(PromptSourceTask):
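In the CoLA hunk above, process_results emits a (target, pred) index pair under the "mcc" key for each document, and aggregation maps that key to matthews_corrcoef, which the evaluator applies to the full list of pairs. A runnable sketch of that aggregation step, assuming the harness's matthews_corrcoef helper wraps sklearn in the usual way:

import sklearn.metrics


def matthews_corrcoef(items):
    # items: list of (target, pred) pairs collected across documents
    golds, preds = zip(*items)
    return sklearn.metrics.matthews_corrcoef(golds, preds)


# Example with three documents, one of them misclassified:
print(matthews_corrcoef([(1, 1), (0, 0), (1, 0)]))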
@@ -116,16 +110,6 @@ class SST(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]

-    def higher_is_better(self):
-        return {
-            "acc": True
-        }
-
-    def aggregation(self):
-        return {
-            "acc": mean
-        }

 # Inference Tasks

@@ -160,19 +144,13 @@ class MNLI(PromptSourceTask):
     def process_results(self, doc, results):
         gold = doc["label"]
         pred = np.argmax(results)
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

 class MNLIMismatched(MNLI):
@@ -213,19 +191,13 @@ class QNLI(Task):
         ll_yes, ll_no = results
         pred = ll_no > ll_yes
         gold = doc["label"]
-        return {
-            "acc": pred == gold
-        }
+        return {"acc": pred == gold}

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

 class WNLI(PromptSourceTask):
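QNLI's process_results above compares the log-likelihoods the model assigns to the two verbalized answers: GLUE label 0 is entailment ("yes") and label 1 is not-entailment ("no"), so pred = ll_no > ll_yes already yields the gold label index as a boolean. A self-contained sketch of the decision rule with made-up scores:

# Hypothetical log-likelihoods for the " yes" and " no" continuations.
ll_yes, ll_no = -2.3, -1.1
pred = ll_no > ll_yes         # True -> class 1 ("not entailment")
gold = 1
print({"acc": pred == gold})  # {'acc': True}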
@@ -252,14 +224,10 @@ class WNLI(PromptSourceTask):
         return self.dataset["validation"]

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

 class RTE(PromptSourceTask):
@@ -285,14 +253,10 @@ class RTE(PromptSourceTask):
         return self.dataset["validation"]

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        return {"acc": True}

     def aggregation(self):
-        return {
-            "acc": mean
-        }
+        return {"acc": mean}

 # Similarity and Paraphrase Tasks

@@ -330,16 +294,10 @@ class MRPC(Task):
         }

     def higher_is_better(self):
-        return {
-            "acc": True,
-            "f1": True
-        }
+        return {"acc": True, "f1": True}

     def aggregation(self):
-        return {
-            "acc": mean,
-            "f1": f1_score
-        }
+        return {"acc": mean, "f1": f1_score}

 class QQP(Task):
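MRPC above (and QQP below) report both accuracy and F1 over the collected (gold, pred) pairs. A sketch of the f1_score aggregator, assuming it unzips the pairs and defers to sklearn, as the harness's helper conventionally does:

import sklearn.metrics


def f1_score(items):
    golds, preds = zip(*items)
    return sklearn.metrics.f1_score(golds, preds)


print(f1_score([(1, 1), (0, 1), (1, 1), (0, 0)]))  # 0.8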
@@ -388,16 +346,10 @@ class QQP(Task):
         }

     def higher_is_better(self):
-        return {
-            "acc": True,
-            "f1": True
-        }
+        return {"acc": True, "f1": True}

     def aggregation(self):
-        return {
-            "acc": mean,
-            "f1": f1_score
-        }
+        return {"acc": mean, "f1": f1_score}

 class STSB(Task):
@@ -435,22 +387,22 @@ class STSB(Task):
         return " {}".format(doc["label"])

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:
             The document as returned from training_docs, validation_docs, or test_docs.
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document

         :param doc:
@@ -459,22 +411,22 @@ class STSB(Task):
             The results of the requests created in construct_requests.
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def aggregation(self):
         """
         :returns: {str: [float] -> float}
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")

     def higher_is_better(self):
         """
         :returns: {str: bool}
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
         # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        raise NotImplementedError("Evaluation not implemented")
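STSB remains unimplemented in this commit; the hunk only normalizes quote style in the NotImplementedError messages. For reference, STS-B is a regression task (similarity scores from 0 to 5) conventionally scored by Pearson and Spearman correlation. A purely hypothetical sketch of what the aggregations could look like, not code from this repository:

import scipy.stats


def pearson(items):
    golds, preds = zip(*items)
    return scipy.stats.pearsonr(golds, preds)[0]


def spearman(items):
    golds, preds = zip(*items)
    return scipy.stats.spearmanr(golds, preds)[0]


print(pearson([(0.0, 0.5), (2.5, 2.0), (5.0, 4.5)]))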