Commit ce823691 authored by ingyuseong

Add KLUE-MRC task (same metric w/ SQuAD 2.0)

parent ceaa7ad3
@@ -19,6 +19,8 @@ from lm_eval.base import Task, MultipleChoiceTask, rf
 from lm_eval.metrics import macro_f1_score, mean, matthews_corrcoef, f1_score, yesno
 from lm_eval.utils import general_detokenize
 from functools import partial
+from sys import exit
+from lm_eval.tasks.datasets.metrics.squad_v2.squad_v2 import SquadV2 as squad_metric
 
 _CITATION = """
 @misc{park2021klue,
@@ -33,8 +35,10 @@ _CITATION = """
 def _squad_metric(predictions, references):
-    squad_metric = datasets.load_metric("squad")
-    return squad_metric.compute(predictions=predictions, references=references)
+    # squad_metric = datasets.load_metric("squad_v2")
+    # return squad_metric.compute(predictions=predictions, references=references)
+    return squad_metric._compute(squad_metric, predictions=predictions, references=references)
 
 def _squad_agg(key, items):
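Not part of the diff: a minimal sketch of how the patched _squad_metric could be exercised, assuming the vendored squad_v2 module mirrors the Hugging Face squad_v2 metric interface (SQuAD-2.0-style predictions and references); the id and answer strings are invented.

# Sketch only (not from the commit): exercising the patched _squad_metric.
# Assumes SquadV2._compute accepts SQuAD-2.0-style inputs, as in the
# Hugging Face squad_v2 metric this module appears to be copied from.
preds = [{
    "id": "klue-mrc-000001",               # invented guid
    "prediction_text": "세종대왕",
    "no_answer_probability": 0.02,
}]
refs = [{
    "id": "klue-mrc-000001",
    "answers": {"text": ["세종대왕"], "answer_start": [42]},
}]
scores = _squad_metric(predictions=preds, references=refs)
print(scores["exact"], scores["f1"])       # 100.0 100.0 for an exact match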
@@ -233,7 +237,9 @@ class MRC(Task):
         return '제목: ' + doc['title'] + '\n\n' + '본문: ' + doc['context'] + '\n\n' + '질문: ' + doc['question'] + '\n\n' + '답:'
 
     def doc_to_target(self, doc):
-        answer = "대답 불가" if doc["is_impossible"] else doc["answers"]["text"][0]
+        answer = doc["answers"]["text"][0]
+        if doc["is_impossible"]:
+            answer = "대답 불가"
         return " " + answer
 
     def construct_requests(self, doc, ctx):
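A quick illustration (not part of the diff) of what the patched doc_to_target produces; "대답 불가" is the literal target for unanswerable questions (roughly, "cannot answer"). Unlike the old one-line conditional, the new code always indexes doc["answers"]["text"][0] before checking is_impossible, so the invented documents below give the unanswerable one a placeholder answer text.

# Sketch only: mirrors the patched doc_to_target on two invented documents.
def target_for(doc):
    answer = doc["answers"]["text"][0]
    if doc["is_impossible"]:
        answer = "대답 불가"
    return " " + answer

answerable = {"answers": {"text": ["세종대왕"]}, "is_impossible": False}
unanswerable = {"answers": {"text": ["(무시됨)"]}, "is_impossible": True}
print(repr(target_for(answerable)))    # ' 세종대왕'
print(repr(target_for(unanswerable)))  # ' 대답 불가'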
@@ -247,7 +253,7 @@ class MRC(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        continuation = rf.greedy_until(ctx, ["\n"])
+        continuation = rf.greedy_until(ctx, ['\n'])
         is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
         return continuation, is_unanswerable
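For context (again not part of the diff): construct_requests issues one greedy-generation request for the answer and one loglikelihood request for the literal " 대답 불가", and process_results in the next hunk exponentiates that log-probability into no_answer_probability. A sketch of how the two results are typically consumed, assuming the harness resolves rf.loglikelihood to a (logprob, is_greedy) pair as in other lm-eval tasks.

from math import exp

# Sketch only: unpacking the two request results.
# results[0] is the generated answer string from rf.greedy_until;
# results[1] is assumed to be (logprob, is_greedy) from rf.loglikelihood.
def unpack_results(results):
    continuation, (logprob_unanswerable, _) = results
    no_answer_probability = exp(logprob_unanswerable)
    return continuation, no_answer_probability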
@@ -266,19 +272,47 @@ class MRC(Task):
         no_answer_probability = exp(logprob_unanswerable)
 
         predictions = {
-            "id": doc["guid"],
-            "prediction_text": continuation,
-            "no_answer_probability": no_answer_probability
+            'id': doc['guid'],
+            'prediction_text': continuation,
+            'no_answer_probability': no_answer_probability,
         }
 
         references = {
-            "id": doc["guid"],
-            "answers": doc["answers"],
+            'id': doc['guid'],
+            'answers': doc['answers'],
+            'unanswerable': doc['is_impossible'],
         }
 
         return {
-            'exact_match': (predictions, references),  # Exact match (the normalized answer exactly match the gold answer)
-            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
+            "exact": (
+                predictions,
+                references,
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "f1": (
+                predictions,
+                references,
+            ),  # The F-score of predicted tokens versus the gold answer
+            "HasAns_exact": (
+                predictions,
+                references,
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "HasAns_f1": (
+                predictions,
+                references,
+            ),  # The F-score of predicted tokens versus the gold answer
+            "NoAns_exact": (
+                predictions,
+                references,
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "NoAns_f1": (
+                predictions,
+                references,
+            ),  # The F-score of predicted tokens versus the gold answer
+            "best_exact": (
+                predictions,
+                references,
+            ),  # Best exact match (with varying threshold)
+            "best_f1": (predictions, references),  # Best F1 (with varying threshold)
         }
 
     def aggregation(self):
@@ -287,9 +321,31 @@ class MRC(Task):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
         return {
-            'exact_match': partial(_squad_agg, 'exact_match'),  # Exact match (the normalized answer exactly match the gold answer)
-            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answer
+            "exact": partial(
+                _squad_agg, "exact"
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "f1": partial(
+                _squad_agg, "f1"
+            ),  # The F-score of predicted tokens versus the gold answer
+            "HasAns_exact": partial(
+                _squad_agg, "HasAns_exact"
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "HasAns_f1": partial(
+                _squad_agg, "HasAns_f1"
+            ),  # The F-score of predicted tokens versus the gold answer
+            "NoAns_exact": partial(
+                _squad_agg, "NoAns_exact"
+            ),  # Exact match (the normalized answer exactly match the gold answer)
+            "NoAns_f1": partial(
+                _squad_agg, "NoAns_f1"
+            ),  # The F-score of predicted tokens versus the gold answer
+            "best_exact": partial(
+                _squad_agg, "best_exact"
+            ),  # Best exact match (with varying threshold)
+            "best_f1": partial(
+                _squad_agg, "best_f1"
+            ),  # Best F1 (with varying threshold)
         }
 
     def higher_is_better(self):
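The aggregation above relies on the pre-existing _squad_agg helper, visible only as context in the earlier hunk. Its body is not shown in this commit; in the upstream harness's SQuAD task the usual pattern is the sketch below, where each submetric key selects one field of a single corpus-level metric computation.

# Sketch only: the typical lm-eval _squad_agg pattern (not shown in this diff).
# items is a list of (predictions, references) pairs emitted by process_results.
def _squad_agg_sketch(key, items):
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references)[key]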
@@ -298,7 +354,13 @@ class MRC(Task):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
         return {
-            'exact_match': True,  # Exact match (the normalized answer exactly match the gold answer)
-            'f1': True,  # The F-score of predicted tokens versus the gold answer
+            "exact": True,  # Exact match (the normalized answer exactly match the gold answer)
+            "f1": True,  # The F-score of predicted tokens versus the gold answer
+            "HasAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
+            "HasAns_f1": True,  # The F-score of predicted tokens versus the gold answer
+            "NoAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
+            "NoAns_f1": True,  # The F-score of predicted tokens versus the gold answer
+            "best_exact": True,  # Best exact match (with varying threshold)
+            "best_f1": True,  # Best F1 (with varying threshold)
         }