""" NOTE: This file implements translation tasks using datasets from https://huggingface.co/datasets/Moo/korean-parallel-corpora """ from datasets import load_dataset from lm_eval import metrics from lm_eval.base import Task, rf ######################################## # DATASET Specifics ######################################## DATASET_PATH: str = "Moo/korean-parallel-corpora" class KoreanTranslationTask(Task): VERSION = 0 def __init__(self): pass def has_training_docs(self): """Whether the task has a training set""" return True def has_validation_docs(self): """Whether the task has a validation set""" return True def has_test_docs(self): """Whether the task has a test set""" return True def training_docs(self): """ :return: Iterable[obj] A iterable of any object, that doc_to_text can handle """ if self._training_docs is None: self._training_docs = [ {"src": src, "tgt": tgt} for src, tgt in zip(self.train_src, self.train_tgt) ] return self._training_docs def validation_docs(self): """ :return: Iterable[obj] A iterable of any object, that doc_to_text can handle """ return [ {"src": src, "tgt": tgt} for src, tgt in zip(self.valid_src, self.valid_tgt) ] def test_docs(self): """ :return: Iterable[obj] A iterable of any object, that doc_to_text can handle """ return [ {"src": src, "tgt": tgt} for src, tgt in zip(self.tst_src, self.tst_tgt) ] def doc_to_text(self, doc): src_lang = self.src_lang tar_lang = self.tar_lang if src_lang == 'ko': return f"{src_lang}을 {tar_lang}으로 번역해주는 모델입니다.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:" elif src_lang == 'en': return f"Translate {src_lang} to {tar_lang}.\n\n###\n{src_lang}:" + doc["src"] + f"\n{tar_lang}:" def should_decontaminate(self): return True def doc_to_decontamination_query(self, doc): return doc["src"] def doc_to_target(self, doc): # This shows a single target, though there may be multiple targets in a lang test return " " + doc["tgt"] if isinstance(doc["tgt"], str) else doc["tgt"][0] def construct_requests(self, doc, ctx): """Uses RequestFactory to construct Requests and returns an iterable of Requests which will be sent to the LM. :param doc: The document as returned from training_docs, validation_docs, or test_docs. :param ctx: str The context string, generated by fewshot_context. This includes the natural language description, as well as the few shot examples, and the question part of the document for `doc`. """ return rf.greedy_until(ctx, ["\n"]) def process_results(self, doc, results): ref_pred = (doc["tgt"], results) return { "bleu": ref_pred, "chrf": ref_pred, "ter": ref_pred, } def aggregation(self): """ :returns: {str: [float] -> float} A dictionary where keys are the names of submetrics and values are functions that aggregate a list of metrics """ return { "bleu": metrics.bleu, "chrf": metrics.chrf, "ter": metrics.ter, } def higher_is_better(self): """ :returns: {str: bool} A dictionary where keys are the names of submetrics and values are whether a higher value of the submetric is better """ return { "bleu": True, "chrf": True, "ter": False, } def __str__(self): return f"{self.src_lang} to {self.tar_lang} Task" class KoEnTranslation(KoreanTranslationTask): def __init__(self): super().__init__() self.dataset = load_dataset(DATASET_PATH) self.src_lang = 'ko' self.tar_lang = 'en' self.train_src = list(self.dataset['train'][self.src_lang]) self.train_tgt = list(self.dataset['train'][self.tar_lang]) self.valid_src = list(self.dataset['validation'][self.src_lang]) self.valid_tgt = list(self.dataset['validation'][self.tar_lang]) self.tst_src = list(self.dataset['test'][self.src_lang]) self.tst_tgt = list(self.dataset['test'][self.tar_lang]) self._training_docs = None self._fewshot_docs = None class EnKoTranslation(KoreanTranslationTask): def __init__(self): super().__init__() self.dataset = load_dataset(DATASET_PATH) self.src_lang = 'en' self.tar_lang = 'ko' self.train_src = list(self.dataset['train'][self.src_lang]) self.train_tgt = list(self.dataset['train'][self.tar_lang]) self.valid_src = list(self.dataset['validation'][self.src_lang]) self.valid_tgt = list(self.dataset['validation'][self.tar_lang]) self.tst_src = list(self.dataset['test'][self.src_lang]) self.tst_tgt = list(self.dataset['test'][self.tar_lang]) self._training_docs = None self._fewshot_docs = None