Commit dd317a75 authored by ingyuseong's avatar ingyuseong
Browse files

Update KLUE-MRC metric (F1 and EM)

parent e8f38aee
...@@ -13,6 +13,7 @@ https://arxiv.org/abs/2105.09680 ...@@ -13,6 +13,7 @@ https://arxiv.org/abs/2105.09680
""" """
import datasets import datasets
import evaluate
from math import exp from math import exp
import numpy as np import numpy as np
from lm_eval.base import Task, MultipleChoiceTask, rf from lm_eval.base import Task, MultipleChoiceTask, rf
...@@ -32,16 +33,16 @@ _CITATION = """ ...@@ -32,16 +33,16 @@ _CITATION = """
""" """
@lru_cache(maxsize=1)
def _load_klue_mrc_metric():
    """Load the KLUE-MRC metric module once and reuse it.

    `evaluate.load` resolves (and may download) the metric script from the
    Hugging Face hub; without caching, every aggregation key (exact, f1,
    HasAns_*, NoAns_*, best_*) triggers a fresh load.
    """
    return evaluate.load("ingyu/klue_mrc")


def _klue_mrc_metric(predictions, references):
    """Compute KLUE-MRC scores (EM / F1 and answerability splits).

    :param predictions: iterable of prediction dicts in SQuAD-v2 style
        (id, prediction_text, no_answer_probability).
    :param references: iterable of reference dicts (id, answers).
    :return: dict of metric name -> score as produced by the
        "ingyu/klue_mrc" evaluate module.
    """
    return _load_klue_mrc_metric().compute(
        predictions=predictions, references=references
    )
def _klue_mrc_agg(key, items):
    """Aggregate per-example (prediction, reference) pairs, return metric `key`.

    :param key: name of the metric to extract (e.g. "exact", "f1").
    :param items: iterable of (prediction, reference) tuples collected per doc.
    :return: the scalar score for `key` over the whole collection.
    """
    preds = []
    refs = []
    for pred, ref in items:
        preds.append(pred)
        refs.append(ref)
    scores = _klue_mrc_metric(predictions=preds, references=refs)
    return scores[key]
class STS(Task): class STS(Task):
...@@ -231,7 +232,7 @@ class MRC(Task): ...@@ -231,7 +232,7 @@ class MRC(Task):
return self.dataset["validation"] return self.dataset["validation"]
def doc_to_text(self, doc):
    """Render a KLUE-MRC example as a Korean title/context/question prompt.

    :param doc: example dict with "title", "context", and "question" keys.
    :return: prompt string ending in the answer cue "답:".
    """
    sections = (
        "제목: " + doc["title"],
        "본문: " + doc["context"],
        "질문: " + doc["question"],
        "답:",
    )
    return "\n\n".join(sections)
def doc_to_target(self, doc): def doc_to_target(self, doc):
answer = doc["answers"]["text"][0] answer = doc["answers"]["text"][0]
...@@ -250,7 +251,7 @@ class MRC(Task): ...@@ -250,7 +251,7 @@ class MRC(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
continuation = rf.greedy_until(ctx, ['\n']) continuation = rf.greedy_until(ctx, {"until": ["\n"]})
is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가") is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
return continuation, is_unanswerable return continuation, is_unanswerable
...@@ -320,28 +321,28 @@ class MRC(Task): ...@@ -320,28 +321,28 @@ class MRC(Task):
""" """
return { return {
"exact": partial( "exact": partial(
_squad_agg, "exact" _klue_mrc_agg, "exact"
), # Exact match (the normalized answer exactly match the gold answer) ), # Exact match (the normalized answer exactly match the gold answer)
"f1": partial( "f1": partial(
_squad_agg, "f1" _klue_mrc_agg, "f1"
), # The F-score of predicted tokens versus the gold answer ), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": partial( "HasAns_exact": partial(
_squad_agg, "HasAns_exact" _klue_mrc_agg, "HasAns_exact"
), # Exact match (the normalized answer exactly match the gold answer) ), # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": partial( "HasAns_f1": partial(
_squad_agg, "HasAns_f1" _klue_mrc_agg, "HasAns_f1"
), # The F-score of predicted tokens versus the gold answer ), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": partial( "NoAns_exact": partial(
_squad_agg, "NoAns_exact" _klue_mrc_agg, "NoAns_exact"
), # Exact match (the normalized answer exactly match the gold answer) ), # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": partial( "NoAns_f1": partial(
_squad_agg, "NoAns_f1" _klue_mrc_agg, "NoAns_f1"
), # The F-score of predicted tokens versus the gold answer ), # The F-score of predicted tokens versus the gold answer
"best_exact": partial( "best_exact": partial(
_squad_agg, "best_exact" _klue_mrc_agg, "best_exact"
), # Best exact match (with varying threshold) ), # Best exact match (with varying threshold)
"best_f1": partial( "best_f1": partial(
_squad_agg, "best_f1" _klue_mrc_agg, "best_f1"
), # Best F1 (with varying threshold) ), # Best F1 (with varying threshold)
} }
......
...@@ -42,7 +42,7 @@ setuptools.setup( ...@@ -42,7 +42,7 @@ setuptools.setup(
], ],
extras_require={ extras_require={
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"], "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"], "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1", "evaluate>=0.4.0"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"], "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
}, },
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment