Unverified commit 10c61137 authored by Seung-Moo Yang, committed by GitHub

Add ko-en translation task for Korean language evaluation (#328)



* feat: Add task for ko-en translation
Co-authored-by: Taekyoon <tgchoi03@gmail.com>
parent 7ce584bd
@@ -54,12 +54,14 @@ from . import storycloze
from . import kobest
from . import nsmc
from . import klue
+ from . import ko_translation
from . import korquad
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
@@ -307,14 +309,16 @@ TASK_REGISTRY = {
# "sat": sat.SATAnalogies,
"klue_sts": klue.STS,
"klue_ynat": klue.YNAT
"klue_ynat": klue.YNAT,
"nsmc": nsmc.NSMC,
"korquad": korquad.Korquad
"korquad": korquad.Korquad,
"kobest_boolq": kobest.BoolQ,
"kobest_copa": kobest.COPA,
"kobest_wic": kobest.WiC,
"kobest_hellaswag": kobest.HellaSwag,
"kobest_sentineg": kobest.SentiNeg
"kobest_sentineg": kobest.SentiNeg,
"ko_en_translation": ko_translation.KoEnTranslation,
"en_ko_translation": ko_translation.EnKoTranslation
}
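With the registry entries above in place, the two new tasks resolve like any other harness task. The snippet below is a minimal sanity-check sketch, not part of this commit; it assumes the lm_eval package from this repository is importable and that the Hugging Face dataset can be downloaded.

from lm_eval.tasks import ko_translation

task = ko_translation.KoEnTranslation()   # loads Moo/korean-parallel-corpora
doc = next(iter(task.validation_docs()))  # {"src": Korean sentence, "tgt": English sentence}
print(task.doc_to_text(doc))              # instruction prompt ending in "en:"
print(task.doc_to_target(doc))            # reference translation, prefixed with a space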
lm_eval/tasks/ko_translation.py (new file)
"""
NOTE: This file implements translation tasks using datasets from https://huggingface.co/datasets/Moo/korean-parallel-corpora
"""
from datasets import load_dataset
from lm_eval import metrics
from lm_eval.base import Task, rf
########################################
# DATASET Specifics
########################################
DATASET_PATH: str = "Moo/korean-parallel-corpora"
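# The subclasses below assume this dataset provides "train", "validation" and "test"
# splits, each holding parallel "ko" and "en" text columns.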
class KoreanTranslationTask(Task):
VERSION = 0
def __init__(self):
pass
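# NOTE: the no-op constructor above deliberately skips the base Task setup; the
# direction-specific subclasses (KoEnTranslation, EnKoTranslation) load the dataset
# and populate the language and split attributes themselves.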
def has_training_docs(self):
"""Whether the task has a training set"""
return True
def has_validation_docs(self):
"""Whether the task has a validation set"""
return True
def has_test_docs(self):
"""Whether the task has a test set"""
return True
def training_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
if self._training_docs is None:
self._training_docs = [
{"src": src, "tgt": tgt} for src, tgt in zip(self.train_src, self.train_tgt)
]
return self._training_docs
def validation_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return [
{"src": src, "tgt": tgt} for src, tgt in zip(self.valid_src, self.valid_tgt)
]
def test_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return [
{"src": src, "tgt": tgt} for src, tgt in zip(self.tst_src, self.tst_tgt)
]
def doc_to_text(self, doc):
src_lang = self.src_lang
tar_lang = self.tar_lang
if src_lang == 'ko':
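# The Korean instruction below roughly reads: "This is a model that translates {src_lang} into {tar_lang}."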
return f"{src_lang}{tar_lang}으로 번역해주는 모델입니다.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:"
elif src_lang == 'en':
return f"Translate {src_lang} to {tar_lang}.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["src"]
def doc_to_target(self, doc):
# This shows a single target, though there may be multiple targets in a lang test
return " " + doc["tgt"] if isinstance(doc["tgt"], str) else doc["tgt"][0]
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results):
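# BLEU, chrF and TER are corpus-level metrics, so the (reference, prediction) pair
# is passed through untouched here and the actual scoring happens in aggregation().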
ref_pred = (doc["tgt"], results)
return {
"bleu": ref_pred,
"chrf": ref_pred,
"ter": ref_pred,
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"bleu": metrics.bleu,
"chrf": metrics.chrf,
"ter": metrics.ter,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"bleu": True,
"chrf": True,
"ter": False,
}
def __str__(self):
return f"{self.src_lang} to {self.tar_lang} Task"
class KoEnTranslation(KoreanTranslationTask):
def __init__(self):
super().__init__()
self.dataset = load_dataset(DATASET_PATH)
self.src_lang = 'ko'
self.tar_lang = 'en'
self.train_src = list(self.dataset['train'][self.src_lang])
self.train_tgt = list(self.dataset['train'][self.tar_lang])
self.valid_src = list(self.dataset['validation'][self.src_lang])
self.valid_tgt = list(self.dataset['validation'][self.tar_lang])
self.tst_src = list(self.dataset['test'][self.src_lang])
self.tst_tgt = list(self.dataset['test'][self.tar_lang])
self._training_docs = None
self._fewshot_docs = None
class EnKoTranslation(KoreanTranslationTask):
def __init__(self):
super().__init__()
self.dataset = load_dataset(DATASET_PATH)
self.src_lang = 'en'
self.tar_lang = 'ko'
self.train_src = list(self.dataset['train'][self.src_lang])
self.train_tgt = list(self.dataset['train'][self.tar_lang])
self.valid_src = list(self.dataset['validation'][self.src_lang])
self.valid_tgt = list(self.dataset['validation'][self.tar_lang])
self.tst_src = list(self.dataset['test'][self.src_lang])
self.tst_tgt = list(self.dataset['test'][self.tar_lang])
self._training_docs = None
self._fewshot_docs = None