"...resnet50_tensorflow.git" did not exist on "d4dd827f8b5686f427d8ed03198cde2f922eb90e"
Commit 4c201b97 authored by cjlovering's avatar cjlovering
Browse files

SST with PS integration. (It was already done.)

parent e49cf8da
...@@ -71,14 +71,6 @@ class CoLA(PromptSourceTask): ...@@ -71,14 +71,6 @@ class CoLA(PromptSourceTask):
answer_choices_list = self.prompt.get_answer_choices_list(doc) answer_choices_list = self.prompt.get_answer_choices_list(doc)
pred = np.argmax(results) pred = np.argmax(results)
target = answer_choices_list.index(self.doc_to_target(doc).strip()) target = answer_choices_list.index(self.doc_to_target(doc).strip())
print("*" * 80)
print(f"DOC: {doc}")
print(f"TEXT: {self.doc_to_text(doc)}")
print(f"STRING TARGET: {self.doc_to_target(doc)} END TARGET")
print(f"TARGET: {target} END TARGET")
print(f"PRED: {pred}")
print("*" * 80)
return {"mcc": (target, pred)} return {"mcc": (target, pred)}
def higher_is_better(self): def higher_is_better(self):
...@@ -141,17 +133,6 @@ class MNLI(PromptSourceTask): ...@@ -141,17 +133,6 @@ class MNLI(PromptSourceTask):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["test_matched"] return self.dataset["test_matched"]
def process_results(self, doc, results):
gold = doc["label"]
pred = np.argmax(results)
return {"acc": pred == gold}
def higher_is_better(self):
return {"acc": True}
def aggregation(self):
return {"acc": mean}
class MNLIMismatched(MNLI): class MNLIMismatched(MNLI):
VERSION = 0 VERSION = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment