Commit a426a39d authored by danny980521's avatar danny980521
Browse files

Update `KLUE-YNAT` prompt

- add a prompt for the klue-ynat task
- remove unnecessary parentheses from example answers
- reformat the code
parent e8f38aee
......@@ -68,8 +68,7 @@ class STS(Task):
def doc_to_text(self, doc):
    """Build the KLUE-STS prompt asking whether two sentences are similar.

    :param doc: a dataset example with "sentence1" and "sentence2" fields.
    :return: the Korean prompt string ending with the answer cue "정답:".
    """
    # Diff residue duplicated the old two-line argument layout; keep only the
    # post-commit single-line form. Runtime strings are unchanged.
    return "질문: 문장 1과 문장 2는 서로 유사한 의미를 가지나요?\n문장 1: {}\n문장 2: {}\n정답:".format(
        general_detokenize(doc["sentence1"]), general_detokenize(doc["sentence2"])
    )
def doc_to_target(self, doc):
......@@ -83,22 +82,13 @@ class STS(Task):
def process_results(self, doc, results):
    """Score one STS example.

    :param doc: example holding the gold binary similarity label under
        doc["labels"]["binary-label"].
    :param results: per-choice model scores; argmax picks the prediction.
    :return: dict with "acc" (prediction correctness) and "f1"
        (a (gold, pred) pair consumed later by the f1 aggregator).
    """
    # Consolidated to the post-commit single-line return; the diff residue
    # carried both the old multi-line and new dict literals.
    pred = np.argmax(results)
    gold = doc["labels"]["binary-label"]
    return {"acc": pred == gold, "f1": (gold, pred)}
def higher_is_better(self):
    """Declare that both tracked metrics improve as they increase."""
    return {"acc": True, "f1": True}
def aggregation(self):
    """Map each metric name to its corpus-level aggregation function.

    ``mean`` and ``f1_score`` are module-level helpers imported elsewhere
    in this file.
    """
    return {"acc": mean, "f1": f1_score}
class YNAT(MultipleChoiceTask):
......@@ -117,7 +107,7 @@ class YNAT(MultipleChoiceTask):
def training_docs(self):
    """Return the processed training split, caching it on first access.

    :return: list of documents produced by ``self._process_doc`` over
        ``self.dataset["train"]``; subsequent calls reuse the cached list.
    """
    if self._training_docs is None:
        # Process lazily exactly once; the cache avoids re-mapping on reuse.
        self._training_docs = list(map(self._process_doc, self.dataset["train"]))
    return self._training_docs
def validation_docs(self):
......@@ -127,32 +117,30 @@ class YNAT(MultipleChoiceTask):
out_doc = {
"title": doc["title"],
"choices": ["과학", "경제", "사회", "생활", "세계", "스포츠", "정치"],
"gold": doc["label"]
"gold": doc["label"],
}
return out_doc
def doc_to_text(self, doc):
    """Build the KLUE-YNAT prompt asking which news category a title belongs to.

    :param doc: processed example with a "title" field.
    :return: the Korean prompt string ending with the answer cue "분야:".
    """
    # Post-commit prompt (the bare-title form on the old line was removed).
    return "질문: 다음의 제목을 가지는 뉴스는 어느 분야의 뉴스인가요?\n제목: {}\n분야:".format(doc["title"])
def doc_to_target(self, doc):
    """Map the gold label index to its Korean category name.

    :param doc: processed example with an integer "gold" label in 0..6.
    :return: the category name prefixed with a single space (post-commit
        form: no parentheses around the answer token).
    """
    return " {}".format(
        {0: "과학", 1: "경제", 2: "사회", 3: "생활", 4: "세계", 5: "스포츠", 6: "정치"}[
            doc["gold"]
        ]
    )
def process_results(self, doc, results):
    """Score one YNAT example.

    :param doc: processed example holding the gold label under doc["gold"].
    :param results: per-choice model scores; argmax picks the prediction.
    :return: dict with "f1" as a (gold, pred) pair consumed by the
        macro-F1 aggregator.
    """
    # Consolidated to the post-commit single-line return.
    pred = np.argmax(results)
    gold = doc["gold"]
    return {"f1": (gold, pred)}
def higher_is_better(self):
    """Declare that the tracked F1 metric improves as it increases."""
    return {"f1": True}
def aggregation(self):
    """Map the F1 metric to its aggregation function.

    ``macro_f1_score`` is a module-level helper imported elsewhere in
    this file.
    """
    return {"f1": macro_f1_score}
class NLI(Task):
......@@ -231,7 +219,18 @@ class MRC(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
    """Build the KLUE-MRC prompt: title, passage, question, then the answer cue.

    :param doc: example with "title", "context", and "question" fields.
    :return: the Korean prompt string ending with "답:".
    """
    # Post-commit multi-line concatenation (the old single-quoted one-liner
    # on the preceding diff line is superseded). Runtime text is identical.
    return (
        "제목: "
        + doc["title"]
        + "\n\n"
        + "본문: "
        + doc["context"]
        + "\n\n"
        + "질문: "
        + doc["question"]
        + "\n\n"
        + "답:"
    )
def doc_to_target(self, doc):
answer = doc["answers"]["text"][0]
......@@ -240,7 +239,7 @@ class MRC(Task):
return " " + answer
def construct_requests(self, doc, ctx):
    """Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    :param doc:
        The document as returned from training_docs, validation_docs, or
        test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    """
    # Greedy generation stops at the first newline; a second loglikelihood
    # request scores the explicit "unanswerable" option (" 대답 불가").
    continuation = rf.greedy_until(ctx, ["\n"])
    is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
    return continuation, is_unanswerable
......@@ -269,15 +268,15 @@ class MRC(Task):
no_answer_probability = exp(logprob_unanswerable)
predictions = {
'id': doc['guid'],
'prediction_text': continuation,
'no_answer_probability': no_answer_probability,
"id": doc["guid"],
"prediction_text": continuation,
"no_answer_probability": no_answer_probability,
}
references = {
'id': doc['guid'],
'answers': doc['answers'],
'unanswerable': doc['is_impossible'],
"id": doc["guid"],
"answers": doc["answers"],
"unanswerable": doc["is_impossible"],
}
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment