Unverified Commit a77f4be9 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #536 from danny980521/update/klue_ynat

Update `KLUE-YNAT` prompt
parents a3b76ab1 d2dd333e
......@@ -69,8 +69,7 @@ class STS(Task):
def doc_to_text(self, doc):
return "질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?\n문장 1: {}\n문장 2: {}\n정답:".format(
general_detokenize(doc["sentence1"]),
general_detokenize(doc["sentence2"])
general_detokenize(doc["sentence1"]), general_detokenize(doc["sentence2"])
)
def doc_to_target(self, doc):
......@@ -84,22 +83,13 @@ class STS(Task):
def process_results(self, doc, results):
pred = np.argmax(results)
gold = doc["labels"]["binary-label"]
return {
"acc": pred == gold,
"f1": (gold, pred)
}
return {"acc": pred == gold, "f1": (gold, pred)}
def higher_is_better(self):
return {
"acc": True,
"f1": True
}
return {"acc": True, "f1": True}
def aggregation(self):
return {
"acc": mean,
"f1": f1_score
}
return {"acc": mean, "f1": f1_score}
class YNAT(MultipleChoiceTask):
......@@ -118,7 +108,7 @@ class YNAT(MultipleChoiceTask):
def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc,self.dataset["train"]))
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs
def validation_docs(self):
......@@ -128,32 +118,30 @@ class YNAT(MultipleChoiceTask):
out_doc = {
"title": doc["title"],
"choices": ["과학", "경제", "사회", "생활", "세계", "스포츠", "정치"],
"gold": doc["label"]
"gold": doc["label"],
}
return out_doc
def doc_to_text(self, doc):
return "{}".format(doc["title"])
return "질문: 다음의 제목을 가지는 뉴스는 어느 분야의 뉴스인가요?\n제목: {}\n분야:".format(doc["title"])
def doc_to_target(self, doc):
return " ({})".format({0: "과학", 1: "경제", 2: "사회", 3: "생활", 4: "세계", 5: "스포츠", 6: "정치"}[doc["gold"]])
return " {}".format(
{0: "과학", 1: "경제", 2: "사회", 3: "생활", 4: "세계", 5: "스포츠", 6: "정치"}[
doc["gold"]
]
)
def process_results(self, doc, results):
pred = np.argmax(results)
gold = doc["gold"]
return {
"f1": (gold, pred)
}
return {"f1": (gold, pred)}
def higher_is_better(self):
return {
"f1": True
}
return {"f1": True}
def aggregation(self):
return {
"f1": macro_f1_score
}
return {"f1": macro_f1_score}
class NLI(Task):
......@@ -232,7 +220,18 @@ class MRC(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
return "제목: " + doc["title"] + "\n\n" + "본문: " + doc["context"] + "\n\n" + "질문: " + doc["question"] + "\n\n" + "답:"
return (
"제목: "
+ doc["title"]
+ "\n\n"
+ "본문: "
+ doc["context"]
+ "\n\n"
+ "질문: "
+ doc["question"]
+ "\n\n"
+ "답:"
)
def doc_to_target(self, doc):
answer = doc["answers"]["text"][0]
......@@ -241,7 +240,7 @@ class MRC(Task):
return " " + answer
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
......@@ -251,7 +250,7 @@ class MRC(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, {"until": ["\n"]})
continuation = rf.greedy_until(ctx, ["\n"])
is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
return continuation, is_unanswerable
......@@ -270,15 +269,15 @@ class MRC(Task):
no_answer_probability = exp(logprob_unanswerable)
predictions = {
'id': doc['guid'],
'prediction_text': continuation,
'no_answer_probability': no_answer_probability,
"id": doc["guid"],
"prediction_text": continuation,
"no_answer_probability": no_answer_probability,
}
references = {
'id': doc['guid'],
'answers': doc['answers'],
'unanswerable': doc['is_impossible'],
"id": doc["guid"],
"answers": doc["answers"],
"unanswerable": doc["is_impossible"],
}
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment