Commit ce823691 authored by ingyuseong

Add KLUE-MRC task (same metric w/ SQuAD 2.0)

parent ceaa7ad3
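Context for the change below: the commit scores KLUE-MRC the same way as SQuAD 2.0, so each prediction carries a no_answer_probability next to the generated text, and the metric reports exact/F1 overall, split into HasAns/NoAns subsets, plus best_exact/best_f1 at the best no-answer threshold. A minimal sketch of that computation using the stock Hugging Face squad_v2 metric (for illustration only; the commit itself imports a vendored SquadV2 class, see the diff), with made-up ids and values:

# Illustrative sketch only, not part of the commit. Uses the stock
# datasets squad_v2 metric; the id and probability below are made up.
import datasets

squad_v2 = datasets.load_metric("squad_v2")

predictions = [{
    "id": "klue-mrc-v1_dev_00001",                # hypothetical guid
    "prediction_text": "",                        # model generated nothing useful
    "no_answer_probability": 0.95,                # exp(loglikelihood of " 대답 불가")
}]
references = [{
    "id": "klue-mrc-v1_dev_00001",
    "answers": {"text": [], "answer_start": []},  # unanswerable gold example
}]

results = squad_v2.compute(predictions=predictions, references=references)
# results includes: exact, f1, NoAns_exact/NoAns_f1 (HasAns_* appear when
# answerable examples are present), best_exact, best_f1 and their thresholds.
print(results["best_exact"], results["best_f1"])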
@@ -19,6 +19,8 @@ from lm_eval.base import Task, MultipleChoiceTask, rf
from lm_eval.metrics import macro_f1_score, mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize
from functools import partial
from sys import exit
from lm_eval.tasks.datasets.metrics.squad_v2.squad_v2 import SquadV2 as squad_metric
_CITATION = """
@misc{park2021klue,
@@ -33,8 +35,10 @@ _CITATION = """
def _squad_metric(predictions, references):
    squad_metric = datasets.load_metric("squad")
    return squad_metric.compute(predictions=predictions, references=references)
    # squad_metric = datasets.load_metric("squad_v2")
    # return squad_metric.compute(predictions=predictions, references=references)
    return squad_metric._compute(squad_metric, predictions=predictions, references=references)
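One non-obvious detail in the new return statement: `_compute` is looked up on the imported SquadV2 class, and the class itself is passed in place of `self`. That only works because the SQuAD v2 compute logic (in the usual metric script) never touches `self`. A tiny self-contained illustration of the pattern, unrelated to the vendored class:

# Standalone illustration of calling an unbound method with the class as `self`.
# This mirrors the trick in _squad_metric above; it is only safe when the
# method body never uses `self`.
class Demo:
    def _compute(self, predictions, references):
        # `self` is deliberately ignored here
        return {"n_pred": len(predictions), "n_ref": len(references)}

result = Demo._compute(Demo, predictions=[0.1, 0.2], references=[1, 2])
assert result == {"n_pred": 2, "n_ref": 2}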
def _squad_agg(key, items):
@@ -233,7 +237,9 @@ class MRC(Task):
        return '제목: ' + doc['title'] + '\n\n' + '본문: ' + doc['context'] + '\n\n' + '질문: ' + doc['question'] + '\n\n' + '답:'

    def doc_to_target(self, doc):
        answer = "대답 불가" if doc["is_impossible"] else doc["answers"]["text"][0]
        answer = doc["answers"]["text"][0]
        if doc["is_impossible"]:
            answer = "대답 불가"
        return " " + answer

    def construct_requests(self, doc, ctx):
@@ -247,7 +253,7 @@ class MRC(Task):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
        continuation = rf.greedy_until(ctx, ["\n"])
        continuation = rf.greedy_until(ctx, ['\n'])
        is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
        return continuation, is_unanswerable
@@ -266,19 +272,47 @@ class MRC(Task):
        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            "id": doc["guid"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability
            'id': doc['guid'],
            'prediction_text': continuation,
            'no_answer_probability': no_answer_probability,
        }

        references = {
            "id": doc["guid"],
            "answers": doc["answers"],
            'id': doc['guid'],
            'answers': doc['answers'],
            'unanswerable': doc['is_impossible'],
        }

        return {
            'exact_match': (predictions, references), # Exact match (the normalized answer exactly matches the gold answer)
            'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
        return {
            "exact": (
                predictions,
                references,
            ), # Exact match (the normalized answer exactly matches the gold answer)
            "f1": (
                predictions,
                references,
            ), # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": (
                predictions,
                references,
            ), # Exact match (the normalized answer exactly matches the gold answer)
            "HasAns_f1": (
                predictions,
                references,
            ), # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": (
                predictions,
                references,
            ), # Exact match (the normalized answer exactly matches the gold answer)
            "NoAns_f1": (
                predictions,
                references,
            ), # The F-score of predicted tokens versus the gold answer
            "best_exact": (
                predictions,
                references,
            ), # Best exact match (with varying threshold)
            "best_f1": (predictions, references), # Best F1 (with varying threshold)
        }
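Every key above is registered with the same `(predictions, references)` pair; the split into the different submetrics happens later inside the metric, which decides whether a prediction counts as "no answer" by comparing `no_answer_probability` against a threshold (best_exact/best_f1 search over that threshold). A toy illustration of how the probability is derived and used, with made-up numbers:

from math import exp

# Toy numbers, not from the commit. The loglikelihood request for " 대답 불가"
# resolves to a log-probability, which process_results exponentiates.
logprob_unanswerable = -0.05
no_answer_probability = exp(logprob_unanswerable)    # ~0.951

# SQuAD 2.0 convention: a prediction is treated as "no answer" when its
# no-answer probability exceeds the threshold; best_exact/best_f1 report the
# scores at the threshold value that maximizes them over the full eval set.
threshold = 0.9                                      # hypothetical value
predicted_unanswerable = no_answer_probability > threshold
print(predicted_unanswerable)                        # True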
    def aggregation(self):
@@ -287,9 +321,31 @@ class MRC(Task):
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics
        """
        return {
            'exact_match': partial(_squad_agg, 'exact_match'), # Exact match (the normalized answer exactly matches the gold answer)
            'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
        return {
            "exact": partial(
                _squad_agg, "exact"
            ), # Exact match (the normalized answer exactly matches the gold answer)
            "f1": partial(
                _squad_agg, "f1"
            ), # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": partial(
                _squad_agg, "HasAns_exact"
            ), # Exact match (the normalized answer exactly matches the gold answer)
            "HasAns_f1": partial(
                _squad_agg, "HasAns_f1"
            ), # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": partial(
                _squad_agg, "NoAns_exact"
            ), # Exact match (the normalized answer exactly matches the gold answer)
            "NoAns_f1": partial(
                _squad_agg, "NoAns_f1"
            ), # The F-score of predicted tokens versus the gold answer
            "best_exact": partial(
                _squad_agg, "best_exact"
            ), # Best exact match (with varying threshold)
            "best_f1": partial(
                _squad_agg, "best_f1"
            ), # Best F1 (with varying threshold)
        }
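For context, `partial(_squad_agg, key)` works because the harness collects one `(predictions, references)` tuple per document under each key and hands the accumulated list to the registered aggregation function. The `_squad_agg` helper (only its signature is visible in this diff) presumably zips those pairs back into parallel sequences and picks one value out of the metric output, along these lines:

# Sketch of the aggregation step; the actual body of _squad_agg in this file
# may differ in detail.
def _squad_agg(key, items):
    predictions, references = zip(*items)   # per-doc pairs -> two parallel tuples
    return _squad_metric(predictions=predictions, references=references).get(key, 0)

# e.g. aggregation()["best_f1"](items) ends up calling _squad_agg("best_f1", items)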
    def higher_is_better(self):
@@ -298,7 +354,13 @@ class MRC(Task):
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better
        """
        return {
            'exact_match': True, # Exact match (the normalized answer exactly matches the gold answer)
            'f1': True, # The F-score of predicted tokens versus the gold answer
        return {
            "exact": True, # Exact match (the normalized answer exactly matches the gold answer)
            "f1": True, # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": True, # Exact match (the normalized answer exactly matches the gold answer)
            "HasAns_f1": True, # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": True, # Exact match (the normalized answer exactly matches the gold answer)
            "NoAns_f1": True, # The F-score of predicted tokens versus the gold answer
            "best_exact": True, # Best exact match (with varying threshold)
            "best_f1": True, # Best F1 (with varying threshold)
        }