Commit dd317a75 authored by ingyuseong
Browse files

Update KLUE-MRC metric (F1 and EM)

parent e8f38aee
......@@ -13,6 +13,7 @@ https://arxiv.org/abs/2105.09680
"""
import datasets
import evaluate
from math import exp
import numpy as np
from lm_eval.base import Task, MultipleChoiceTask, rf
......@@ -32,16 +33,16 @@ _CITATION = """
"""
def _squad_metric(predictions, references):
squad_metric = datasets.load_metric("squad_v2")
def _klue_mrc_metric(predictions, references):
klue_mrc_metric = evaluate.load("ingyu/klue_mrc")
return squad_metric.compute(predictions=predictions, references=references)
return klue_mrc_metric.compute(predictions=predictions, references=references)
def _squad_agg(key, items):
def _klue_mrc_agg(key, items):
predictions, references = zip(*items)
return _squad_metric(predictions=predictions, references=references)[key]
return _klue_mrc_metric(predictions=predictions, references=references)[key]
class STS(Task):
......@@ -231,7 +232,7 @@ class MRC(Task):
return self.dataset["validation"]
def doc_to_text(self, doc):
return '제목: ' + doc['title'] + '\n\n' + '본문: ' + doc['context'] + '\n\n' + '질문: ' + doc['question'] + '\n\n' + '답:'
return "제목: " + doc["title"] + "\n\n" + "본문: " + doc["context"] + "\n\n" + "질문: " + doc["question"] + "\n\n" + "답:"
def doc_to_target(self, doc):
answer = doc["answers"]["text"][0]
......@@ -250,7 +251,7 @@ class MRC(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, ['\n'])
continuation = rf.greedy_until(ctx, {"until": ["\n"]})
is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
return continuation, is_unanswerable
......@@ -320,28 +321,28 @@ class MRC(Task):
"""
return {
"exact": partial(
_squad_agg, "exact"
_klue_mrc_agg, "exact"
), # Exact match (the normalized answer exactly match the gold answer)
"f1": partial(
_squad_agg, "f1"
_klue_mrc_agg, "f1"
), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": partial(
_squad_agg, "HasAns_exact"
_klue_mrc_agg, "HasAns_exact"
), # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": partial(
_squad_agg, "HasAns_f1"
_klue_mrc_agg, "HasAns_f1"
), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": partial(
_squad_agg, "NoAns_exact"
_klue_mrc_agg, "NoAns_exact"
), # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": partial(
_squad_agg, "NoAns_f1"
_klue_mrc_agg, "NoAns_f1"
), # The F-score of predicted tokens versus the gold answer
"best_exact": partial(
_squad_agg, "best_exact"
_klue_mrc_agg, "best_exact"
), # Best exact match (with varying threshold)
"best_f1": partial(
_squad_agg, "best_f1"
_klue_mrc_agg, "best_f1"
), # Best F1 (with varying threshold)
}
......
......@@ -42,7 +42,7 @@ setuptools.setup(
],
extras_require={
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1", "evaluate>=0.4.0"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment