Commit 16451834 authored by ingyuseong

Add KLUE-MRC task (EM, F1)

parent a956bc63
@@ -312,6 +312,7 @@ TASK_REGISTRY = {
"klue_sts": klue.STS,
"klue_ynat": klue.YNAT,
"klue_nli": klue.NLI,
"klue_mrc": klue.MRC,
"nsmc": nsmc.NSMC,
"korquad": korquad.Korquad,
"kobest_boolq": kobest.BoolQ,
@@ -12,6 +12,8 @@ https://arxiv.org/abs/2105.09680
Homepage: https://klue-benchmark.com/
"""
import datasets
from functools import partial  # used by MRC.aggregation below
from math import exp
import numpy as np
from lm_eval.base import Task, MultipleChoiceTask, rf
from lm_eval.metrics import macro_f1_score, mean, matthews_corrcoef, f1_score, yesno
@@ -29,6 +31,17 @@ _CITATION = """
"""
def _squad_metric(predictions, references):
    # KLUE-MRC includes unanswerable questions and the predictions carry a
    # no_answer_probability field, so the SQuAD v2 metric is needed here;
    # the plain "squad" metric does not accept that field.
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)
def _squad_agg(key, items):
predictions, references = zip(*items)
return _squad_metric(predictions=predictions, references=references)[key]
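# Each scored document contributes one (predictions, references) pair; the
# harness gathers them into `items`, and _squad_agg unzips those pairs and
# reads a single key (e.g. "f1") from the corpus-level squad_v2 result.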
class STS(Task):
VERSION = 0
DATASET_PATH = "klue"
@@ -106,7 +119,7 @@ class YNAT(MultipleChoiceTask):
return self._training_docs
def validation_docs(self):
-        return map(self._process_doc,self.dataset["validation"])
+        return map(self._process_doc, self.dataset["validation"])
def _process_doc(self, doc):
out_doc = {
@@ -170,9 +183,11 @@ class NLI(Task):
)
def doc_to_target(self, doc):
-        # 참 = entailment
-        # 거짓 = contradiction
-        # 무관 = neutral
+        """
+        참 = entailment
+        중립 = neutral
+        거짓 = contradiction
+        """
return " {}".format({0: "참", 1: "중립", 2: "거짓"}[doc["label"]])
def construct_requests(self, doc, ctx):
@@ -191,3 +206,98 @@ class NLI(Task):
def aggregation(self):
return {"acc": mean}
class MRC(Task):
VERSION = 0
DATASET_PATH = "klue"
DATASET_NAME = "mrc"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def doc_to_text(self, doc):
        return (
            f"제목: {doc['title']}\n\n"
            f"본문: {doc['context']}\n\n"
            f"질문: {doc['question']}\n\n"
            "답:"
        )
def doc_to_target(self, doc):
answer = "대답 불가" if doc["is_impossible"] else doc["answers"]["text"][0]
return " " + answer
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, ["\n"])
is_unanswerable = rf.loglikelihood(ctx, " " + "대답 불가")
return continuation, is_unanswerable
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
continuation, (logprob_unanswerable, _) = results
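        # exp() maps the summed log-probability of " 대답 불가" to a probability
        # in [0, 1]; squad_v2 uses it for its no-answer threshold.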
no_answer_probability = exp(logprob_unanswerable)
predictions = {
"id": doc["guid"],
"prediction_text": continuation,
"no_answer_probability": no_answer_probability
}
references = {
"id": doc["guid"],
"answers": doc["answers"],
}
return {
            'exact_match': (predictions, references),  # Exact match (the normalized prediction exactly matches a gold answer)
            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answers
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
            'exact_match': partial(_squad_agg, 'exact'),  # squad_v2 reports exact match under the key "exact"
            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answers
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
            'exact_match': True,  # Exact match (the normalized prediction exactly matches a gold answer)
            'f1': True,  # The F-score of predicted tokens versus the gold answers
}
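A minimal end-to-end sketch of how the new pieces compose, assuming the harness's usual flow. This is illustrative only, not part of the commit; the guid and answer values are made up:

    # One (predictions, references) pair, shaped as MRC.process_results emits it,
    # reduced to corpus-level scores by _squad_agg (via the squad_v2 metric).
    preds = {
        "id": "klue-mrc-v1_dev_00001",  # hypothetical guid, for illustration only
        "prediction_text": "서울",
        "no_answer_probability": 0.01,
    }
    refs = {
        "id": "klue-mrc-v1_dev_00001",
        "answers": {"text": ["서울"], "answer_start": [42]},
    }
    em = _squad_agg("exact", [(preds, refs)])  # exact match under squad_v2's "exact" key
    f1 = _squad_agg("f1", [(preds, refs)])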