"vscode:/vscode.git/clone" did not exist on "51e77964756107c430c095688c29436d291aca8b"
Commit e49cf8da authored by cjlovering's avatar cjlovering
Browse files

SST with PS integration. (It was already done.)

parent 31a019c2
......@@ -67,7 +67,7 @@ class CoLA(PromptSourceTask):
def validation_docs(self):
    """Return the validation split of the underlying HF dataset."""
    split = self.dataset["validation"]
    return split
def process_results(self, doc, results):
def process_results(self, doc, results):
answer_choices_list = self.prompt.get_answer_choices_list(doc)
pred = np.argmax(results)
target = answer_choices_list.index(self.doc_to_target(doc).strip())
......@@ -79,19 +79,13 @@ class CoLA(PromptSourceTask):
print(f"PRED: {pred}")
print("*" * 80)
return {
"mcc": (target, pred)
}
return {"mcc": (target, pred)}
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    # Matthews correlation coefficient: higher is better.
    return {"mcc": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # matthews_corrcoef is imported at module level (not visible in this chunk).
    return {"mcc": matthews_corrcoef}
class SST(PromptSourceTask):
......@@ -116,16 +110,6 @@ class SST(PromptSourceTask):
def validation_docs(self):
    """Return the validation split of the underlying HF dataset."""
    docs = self.dataset["validation"]
    return docs
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    # Accuracy: higher is better.
    return {"acc": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` is imported at module level (not visible in this chunk).
    return {"acc": mean}
# Inference Tasks
......@@ -160,19 +144,13 @@ class MNLI(PromptSourceTask):
def process_results(self, doc, results):
    """Score one document: compare the argmax prediction with the gold label.

    :param doc: a dataset example; ``doc["label"]`` holds the gold class index.
    :param results: per-choice model scores; the argmax index is the prediction.
    :returns: {"acc": bool} — whether the prediction equals the gold label.
    """
    gold = doc["label"]
    pred = np.argmax(results)
    return {"acc": pred == gold}
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    return {"acc": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` is imported at module level (not visible in this chunk).
    return {"acc": mean}
class MNLIMismatched(MNLI):
......@@ -213,19 +191,13 @@ class QNLI(Task):
ll_yes, ll_no = results
pred = ll_no > ll_yes
gold = doc["label"]
return {
"acc": pred == gold
}
return {"acc": pred == gold}
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    return {"acc": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` is imported at module level (not visible in this chunk).
    return {"acc": mean}
class WNLI(PromptSourceTask):
......@@ -252,14 +224,10 @@ class WNLI(PromptSourceTask):
return self.dataset["validation"]
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    return {"acc": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` is imported at module level (not visible in this chunk).
    return {"acc": mean}
class RTE(PromptSourceTask):
......@@ -285,14 +253,10 @@ class RTE(PromptSourceTask):
return self.dataset["validation"]
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    return {"acc": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` is imported at module level (not visible in this chunk).
    return {"acc": mean}
# Similarity and Paraphrase Tasks
......@@ -330,16 +294,10 @@ class MRPC(Task):
}
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    # Both accuracy and F1: higher is better.
    return {"acc": True, "f1": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` and `f1_score` are imported at module level (not visible here).
    return {"acc": mean, "f1": f1_score}
class QQP(Task):
......@@ -388,16 +346,10 @@ class QQP(Task):
}
def higher_is_better(self):
    """Map each metric name to whether a larger value is better.

    :returns: {str: bool}
    """
    return {"acc": True, "f1": True}
def aggregation(self):
    """Map each metric name to the function that aggregates per-doc values.

    :returns: {str: [float] -> float}
    """
    # `mean` and `f1_score` are imported at module level (not visible here).
    return {"acc": mean, "f1": f1_score}
class STSB(Task):
......@@ -435,22 +387,22 @@ class STSB(Task):
return " {}".format(doc["label"])
def construct_requests(self, doc, ctx):
    """Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    """
    # TODO: implement evaluation.
    raise NotImplementedError("Evaluation not implemented")
def process_results(self, doc, results):
    """Take a single document and the LM results and evaluates, returning a
    dict where keys are the names of submetrics and values are the values of
    the metric for that one document

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param results:
        The results of the requests created in construct_requests.
    """
    # TODO: implement evaluation.
    raise NotImplementedError("Evaluation not implemented")
def aggregation(self):
    """
    :returns: {str: [float] -> float}
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics
    """
    # TODO: implement evaluation.
    raise NotImplementedError("Evaluation not implemented")
def higher_is_better(self):
    """
    :returns: {str: bool}
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better
    """
    # TODO: implement evaluation.
    raise NotImplementedError("Evaluation not implemented")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment