# ko_translation.py
"""
NOTE: This file implements translation tasks using datasets from https://huggingface.co/datasets/Moo/korean-parallel-corpora

"""

from datasets import load_dataset
from lm_eval import metrics
from lm_eval.base import Task, rf


########################################
# Dataset specifics
########################################
DATASET_PATH: str = "Moo/korean-parallel-corpora"
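# The split loaders below assume the Hugging Face dataset exposes
# "train"/"validation"/"test" splits with "ko" and "en" text columns, i.e. a
# record shaped like {"ko": "...", "en": "..."} (illustrative, not a verbatim record).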


class KoreanTranslationTask(Task):
    """Base class for Korean<->English translation tasks over the parallel corpus."""

    VERSION = 0

    def __init__(self):
        # No-op: the direction-specific subclasses load the dataset and set the
        # split attributes (train_src, train_tgt, ...) in their own __init__.
        pass

    def has_training_docs(self):
        """Whether the task has a training set"""
        return True

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return True

    def has_test_docs(self):
        """Whether the task has a test set"""
        return True

    def training_docs(self):
        """
        :return: Iterable[obj]
            An iterable of objects that doc_to_text can handle
        """
        if self._training_docs is None:
            self._training_docs = [
                {"src": src, "tgt": tgt}
                for src, tgt in zip(self.train_src, self.train_tgt)
            ]
        return self._training_docs

    def validation_docs(self):
        """
        :return: Iterable[obj]
            An iterable of objects that doc_to_text can handle
        """
        return [
            {"src": src, "tgt": tgt} for src, tgt in zip(self.valid_src, self.valid_tgt)
        ]

    def test_docs(self):
        """
        :return: Iterable[obj]
            An iterable of objects that doc_to_text can handle
        """
        return [
            {"src": src, "tgt": tgt} for src, tgt in zip(self.tst_src, self.tst_tgt)
        ]
  
    def doc_to_text(self, doc):
        src_lang = self.src_lang
        tar_lang = self.tar_lang
        if src_lang == 'ko':
            # Korean instruction: "A model that translates {src_lang} into {tar_lang}."
            return f"{src_lang}를 {tar_lang}으로 번역해주는 모델입니다.\n\n###\n{src_lang}:{doc['src']}\n{tar_lang}:"
        elif src_lang == 'en':
            return f"Translate {src_lang} to {tar_lang}.\n\n###\n{src_lang}:{doc['src']}\n{tar_lang}:"
            
    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["src"]

    def doc_to_target(self, doc):
        # A doc may carry a single target string or a list of references; use
        # the first reference in the latter case, with a leading space either way.
        tgt = doc["tgt"] if isinstance(doc["tgt"], str) else doc["tgt"][0]
        return " " + tgt

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # Ask for greedy generation, stopping at the first newline; the
        # continuation is scored as the model's translation of doc["src"].
        return rf.greedy_until(ctx, ["\n"])

    def process_results(self, doc, results):
        ref_pred = (doc["tgt"], results)
        return {
            "bleu": ref_pred,
            "chrf": ref_pred,
            "ter": ref_pred,
        }
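
    # The (reference, prediction) pairs above are collected across all docs and
    # handed to the aggregation functions below, which turn them into
    # corpus-level BLEU/chrF/TER scores (lm_eval.metrics is sacrebleu-backed in
    # the harness version this file targets; treat that detail as an assumption).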

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "bleu": metrics.bleu,
            "chrf": metrics.chrf,
            "ter": metrics.ter,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "bleu": True,
            "chrf": True,
            "ter": False,
        }

    def __str__(self):
        return f"{self.src_lang} to {self.tar_lang} Task"


class KoEnTranslation(KoreanTranslationTask):
    """Korean -> English translation."""

    def __init__(self):
        super().__init__()
        self.dataset = load_dataset(DATASET_PATH)

        self.src_lang = 'ko'
        self.tar_lang = 'en'
        
        self.train_src = list(self.dataset['train'][self.src_lang])
        self.train_tgt = list(self.dataset['train'][self.tar_lang])
        self.valid_src = list(self.dataset['validation'][self.src_lang])
        self.valid_tgt = list(self.dataset['validation'][self.tar_lang])
        self.tst_src = list(self.dataset['test'][self.src_lang])
        self.tst_tgt = list(self.dataset['test'][self.tar_lang])
        self._training_docs = None
        self._fewshot_docs = None


class EnKoTranslation(KoreanTranslationTask):
    """English -> Korean translation."""

    def __init__(self):
        super().__init__()
        self.dataset = load_dataset(DATASET_PATH)
        self.src_lang = 'en'
        self.tar_lang = 'ko'
        
        self.train_src = list(self.dataset['train'][self.src_lang])
        self.train_tgt = list(self.dataset['train'][self.tar_lang])
        self.valid_src = list(self.dataset['validation'][self.src_lang])
        self.valid_tgt = list(self.dataset['validation'][self.tar_lang])
        self.tst_src = list(self.dataset['test'][self.src_lang])
        self.tst_tgt = list(self.dataset['test'][self.tar_lang])
        
        self._training_docs = None
        self._fewshot_docs = None
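

# Minimal smoke-test sketch, not part of the lm_eval task API: it instantiates
# one direction, renders a prompt for the first validation doc, and prints the
# target. Running it downloads the corpus and assumes the splits/columns noted
# near DATASET_PATH are present.
if __name__ == "__main__":
    task = KoEnTranslation()
    doc = task.validation_docs()[0]
    print(task.doc_to_text(doc))
    print(repr(task.doc_to_target(doc)))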