Commit 34c60176 authored by ingyuseong's avatar ingyuseong
Browse files

Refactor KLUE-STS calculating pred

parent 26e55d33
...@@ -61,13 +61,12 @@ class STS(Task): ...@@ -61,13 +61,12 @@ class STS(Task):
return " {}".format({0: "아니오", 1: "예"}[doc["labels"]["binary-label"]]) return " {}".format({0: "아니오", 1: "예"}[doc["labels"]["binary-label"]])
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ll_positive, _ = rf.loglikelihood(ctx, " 예")
ll_negative, _ = rf.loglikelihood(ctx, " 아니오") ll_negative, _ = rf.loglikelihood(ctx, " 아니오")
return ll_positive, ll_negative ll_positive, _ = rf.loglikelihood(ctx, " 예")
return ll_negative, ll_positive
def process_results(self, doc, results): def process_results(self, doc, results):
ll_positive, ll_negative = results pred = np.argmax(results)
pred = ll_positive > ll_negative
gold = doc["labels"]["binary-label"] gold = doc["labels"]["binary-label"]
return { return {
"acc": pred == gold, "acc": pred == gold,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment