Unverified commit e8f38aee authored by Yun Geonil, committed by GitHub

Merge pull request #515 from ingyuseong/feature/klue-mrc

Add `KLUE-MRC` task
parents 55f5de6b b8f59dee
@@ -331,6 +331,7 @@ TASK_REGISTRY = {
"klue_sts": klue.STS,
"klue_ynat": klue.YNAT,
"klue_nli": klue.NLI,
"klue_mrc": klue.MRC,
"nsmc": nsmc.NSMC,
"korquad": korquad.Korquad,
"kobest_boolq": kobest.BoolQ,
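# With the "klue_mrc" entry registered above, the task resolves by name. A minimal
# sketch, assuming the upstream lm-evaluation-harness helper `lm_eval.tasks.get_task`
# is available in this fork (names here are illustrative, not part of this commit):
from lm_eval import tasks
mrc_task = tasks.get_task("klue_mrc")()  # looks up TASK_REGISTRY and instantiates klue.MRC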
@@ -12,10 +12,13 @@ https://arxiv.org/abs/2105.09680
Homepage: https://klue-benchmark.com/
"""
import datasets
from math import exp
import numpy as np
from lm_eval.base import Task, MultipleChoiceTask, rf
from lm_eval.metrics import macro_f1_score, mean, matthews_corrcoef, f1_score, yesno
from lm_eval.utils import general_detokenize
from functools import partial
_CITATION = """
@misc{park2021klue,
@@ -29,6 +32,18 @@ _CITATION = """
"""
def _squad_metric(predictions, references):
squad_metric = datasets.load_metric("squad_v2")
return squad_metric.compute(predictions=predictions, references=references)
def _squad_agg(key, items):
predictions, references = zip(*items)
return _squad_metric(predictions=predictions, references=references)[key]
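# Shape note (illustrative; ids and spans below are placeholders): the `datasets`
# squad_v2 metric scores prediction dicts of the form
#   {"id": "...", "prediction_text": "...", "no_answer_probability": 0.0}
# against reference dicts of the form
#   {"id": "...", "answers": {"text": ["..."], "answer_start": [0]}},
# which matches what MRC.process_results below assembles for each document.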
class STS(Task):
VERSION = 0
DATASET_PATH = "klue"
@@ -106,7 +121,7 @@ class YNAT(MultipleChoiceTask):
return self._training_docs
def validation_docs(self):
return map(self._process_doc,self.dataset["validation"])
return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
out_doc = {
@@ -170,9 +185,11 @@ class NLI(Task):
)
def doc_to_target(self, doc):
# 참 = entailment
# 거짓 = contradiction
# 무관 = neutral
"""
참 = entailment
거짓 = contradiction
무관 = neutral
"""
return " {}".format({0: "참", 1: "중립", 2: "거짓"}[doc["label"]])
def construct_requests(self, doc, ctx):
@@ -191,3 +208,156 @@ class NLI(Task):
def aggregation(self):
return {"acc": mean}
class MRC(Task):
VERSION = 0
DATASET_PATH = "klue"
DATASET_NAME = "mrc"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
    return (
        f"제목: {doc['title']}\n\n"
        f"본문: {doc['context']}\n\n"
        f"질문: {doc['question']}\n\n"
        "답:"
    )
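# Rendered prompt (illustrative; fields come straight from the KLUE-MRC example):
#
#   제목: <title>
#
#   본문: <context>
#
#   질문: <question>
#
#   답:
#
# The model's greedy continuation after "답:" is read off as the predicted answer.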
def doc_to_target(self, doc):
    # Unanswerable questions carry no gold span, so check the flag before indexing.
    if doc["is_impossible"]:
        answer = "대답 불가"
    else:
        answer = doc["answers"]["text"][0]
    return " " + answer
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, ['\n'])
is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
return continuation, is_unanswerable
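# Two requests are issued per document: `continuation` resolves to the greedily
# decoded answer string (stopping at the first newline), and `is_unanswerable`
# resolves to a (log-likelihood, is_greedy) pair for the literal " 대답 불가"
# continuation; its exponentiated log-likelihood becomes the no-answer probability
# used in process_results.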
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
continuation, (logprob_unanswerable, _) = results
no_answer_probability = exp(logprob_unanswerable)
predictions = {
'id': doc['guid'],
'prediction_text': continuation,
'no_answer_probability': no_answer_probability,
}
references = {
'id': doc['guid'],
'answers': doc['answers'],
'unanswerable': doc['is_impossible'],
}
return {
"exact": (
predictions,
references,
), # Exact match (the normalized answer exactly matches the gold answer)
"f1": (
predictions,
references,
), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": (
predictions,
references,
), # Exact match (the normalized answer exactly matches the gold answer)
"HasAns_f1": (
predictions,
references,
), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": (
predictions,
references,
), # Exact match (the normalized answer exactly matches the gold answer)
"NoAns_f1": (
predictions,
references,
), # The F-score of predicted tokens versus the gold answer
"best_exact": (
predictions,
references,
), # Best exact match (with varying threshold)
"best_f1": (predictions, references), # Best F1 (with varying threshold)
}
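# Every submetric key maps to the same (predictions, references) pair; the actual
# scoring happens in aggregation(), where _squad_agg collects the pairs across the
# whole split and reads the requested key out of the squad_v2 metric's output.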
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"exact": partial(
_squad_agg, "exact"
), # Exact match (the normalized answer exactly matches the gold answer)
"f1": partial(
_squad_agg, "f1"
), # The F-score of predicted tokens versus the gold answer
"HasAns_exact": partial(
_squad_agg, "HasAns_exact"
), # Exact match (the normalized answer exactly matches the gold answer)
"HasAns_f1": partial(
_squad_agg, "HasAns_f1"
), # The F-score of predicted tokens versus the gold answer
"NoAns_exact": partial(
_squad_agg, "NoAns_exact"
), # Exact match (the normalized answer exactly matches the gold answer)
"NoAns_f1": partial(
_squad_agg, "NoAns_f1"
), # The F-score of predicted tokens versus the gold answer
"best_exact": partial(
_squad_agg, "best_exact"
), # Best exact match (with varying threshold)
"best_f1": partial(
_squad_agg, "best_f1"
), # Best F1 (with varying threshold)
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"exact": True, # Exact match (the normalized answer exactly match the gold answer)
"f1": True, # The F-score of predicted tokens versus the gold answer
"HasAns_exact": True, # Exact match (the normalized answer exactly match the gold answer)
"HasAns_f1": True, # The F-score of predicted tokens versus the gold answer
"NoAns_exact": True, # Exact match (the normalized answer exactly match the gold answer)
"NoAns_f1": True, # The F-score of predicted tokens versus the gold answer
"best_exact": True, # Best exact match (with varying threshold)
"best_f1": True, # Best F1 (with varying threshold)
}