import argparse
import os

import yaml


# Per-language prompt strings used to build the MGSM task configs.
#
# Each entry maps a dataset language code to:
#   QUESTION -- label placed before the question text; gen_lang_yamls also
#               uses it as a stop sequence in ``generation_kwargs["until"]``.
#   ANSWER   -- step-by-step answer prompt (chain-of-thought modes).
#   DIRECT   -- answer prompt for the "direct" mode.
#   REGEX    -- "strict-match" extraction pattern applied in "native-cot"
#               mode; its single capture group matches an optionally negative
#               number that may contain "." and "," separators.
#
# Non-ASCII values are spelled with \u escapes; the comment above each such
# value shows the rendered text.
LANGUAGES = {
    "bn": {  # Bengali
        # "প্রশ্ন:"
        "QUESTION": "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:",
        # "ধাপে ধাপে উত্তর:"
        "ANSWER": "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:",
        "DIRECT": "Answer:",
        "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
    },
    "de": {  # German
        "QUESTION": "Frage:",
        # "Schritt-für-Schritt-Antwort:"
        "ANSWER": "Schritt-f\u00fcr-Schritt-Antwort:",
        "DIRECT": "Antwort:",
        "REGEX": "Die Antwort lautet (\\-?[0-9\\.\\,]+)",
    },
    "en": {  # English
        "QUESTION": "Question:",
        "ANSWER": "Step-by-Step Answer:",
        "DIRECT": "Answer:",
        "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
    },
    "es": {  # Spanish
        "QUESTION": "Pregunta:",
        "ANSWER": "Respuesta paso a paso:",
        "DIRECT": "Respuesta:",
        "REGEX": "La respuesta es (\\-?[0-9\\.\\,]+)",
    },
    "fr": {  # French
        "QUESTION": "Question :",
        # "Réponse étape par étape :"
        "ANSWER": "R\u00e9ponse \u00e9tape par \u00e9tape :",
        # "Réponse :"
        "DIRECT": "R\u00e9ponse :",
        # "La réponse est <number>"
        "REGEX": "La r\u00e9ponse est (\\-?[0-9\\.\\,]+)",
    },
    "ru": {  # Russian
        # "Задача:"
        "QUESTION": "\u0417\u0430\u0434\u0430\u0447\u0430:",
        # "Пошаговоерешение:"
        # NOTE(review): the two words are not separated by a space. Left as-is
        # because gen_lang_yamls strips len(ANSWER) + 1 characters from the
        # dataset's answer field; presumably the dataset uses this exact
        # spelling -- confirm before changing.
        "ANSWER": "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:",
        "DIRECT": "Answer:",
        # "Ответ — <number>"
        "REGEX": "\u041e\u0442\u0432\u0435\u0442 \u2014 (\\-?[0-9\\.\\,]+)",
    },
    "sw": {  # Swahili
        "QUESTION": "Swali:",
        "ANSWER": "Jibu la Hatua kwa Hatua:",
        "DIRECT": "Answer:",
        "REGEX": "Jibu ni (\\-?[0-9\\.\\,]+)",
    },
    "te": {  # Telugu
        # "ప్రశ్న:"
        "QUESTION": "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:",
        # "దశలవారీగా సమాధానం:"
        "ANSWER": "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:",
        "DIRECT": "Answer:",
        # "సమాధానం <number>"
        "REGEX": "\u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02 (\\-?[0-9\\.\\,]+)",
    },
    "th": {  # Thai
        # "โจทย์:"
        "QUESTION": "\u0e42\u0e08\u0e17\u0e22\u0e4c:",
        # "คำตอบทีละขั้นตอน:"
        "ANSWER": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:",
        "DIRECT": "Answer:",
        # "คำตอบคือ <number>"
        "REGEX": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e04\u0e37\u0e2d (\\-?[0-9\\.\\,]+)",
    },
    "ja": {  # Japanese
        # "問題:"
        "QUESTION": "\u554f\u984c:",
        # "ステップごとの答え:"
        "ANSWER": "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:",
        "DIRECT": "Answer:",
        # "答えは<number>です。"
        "REGEX": "\u7b54\u3048\u306f(\\-?[0-9\\.\\,]+)\u3067\u3059\u3002",
    },
    "zh": {  # Chinese
        # "问题:"
        "QUESTION": "\u95ee\u9898:",
        # "逐步解答:"
        "ANSWER": "\u9010\u6b65\u89e3\u7b54:",
        "DIRECT": "Answer:",
        # "答案是 <number>。"
        "REGEX": "\u7b54\u6848\u662f (\\-?[0-9\\.\\,]+)\u3002",
    },
}


def add_regex_pattern(regex_pattern):
    """
    Build the answer-extraction ``filter_list`` fragment of a task config.

    :param regex_pattern: Pattern for the "strict-match" pipeline, or ``None``.
    :return: ``{}`` when ``regex_pattern`` is ``None``. Otherwise a dict whose
        only key, ``"filter_list"``, holds two named pipelines, each a regex
        step followed by a ``take_first`` step:

        * ``strict-match`` -- regex step using ``regex_pattern`` as given.
        * ``flexible-extract`` -- regex step using a generic number pattern
          with ``group_select: -1``.
    """
    if regex_pattern is not None:

        def pipeline(name, regex_step):
            # Build fresh dicts on every call so no object is shared between
            # pipelines (yaml.dump would render shared objects as aliases).
            return {
                "name": name,
                "filter": [regex_step, {"function": "take_first"}],
            }

        strict_step = {"function": "regex", "regex_pattern": f"{regex_pattern}"}
        flexible_step = {
            "function": "regex",
            "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
            "group_select": -1,
        }
        return {
            "filter_list": [
                pipeline("strict-match", strict_step),
                pipeline("flexible-extract", flexible_step),
            ]
        }

    return {}


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    Writes one ``<task_name>.yaml`` per entry in ``LANGUAGES`` into
    ``output_dir``. Every language is attempted even when some target files
    already exist; those collisions are reported together at the end.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: One of:
        "direct"     -- the language's DIRECT prompt, ``direct_yaml`` template;
        "native-cot" -- the language's step-by-step ANSWER prompt plus the
                        language's REGEX answer filters, ``cot_yaml`` template;
        "en-cot"     -- the English step-by-step prompt for every language,
                        ``cot_yaml`` template.
    :raises ValueError: If ``mode`` is not one of the values above.
    :raises FileExistsError: If ``overwrite`` is False and any target file
        already existed; the message lists the skipped file names.
    """
    # Validate up front: an unknown mode previously left the prompt and task
    # name unbound and failed with an UnboundLocalError inside the loop.
    if mode not in ("direct", "native-cot", "en-cot"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'direct', 'native-cot' or 'en-cot'"
        )

    err = []
    for lang, strings in LANGUAGES.items():
        question = strings["QUESTION"]
        yaml_template = "cot_yaml"
        filter_list = {}
        delimiter = None

        if mode == "direct":
            answer_prompt = strings["DIRECT"]
            task_name = f"mgsm_direct_{lang}"
            yaml_template = "direct_yaml"
        elif mode == "native-cot":
            answer_prompt = strings["ANSWER"]
            task_name = f"mgsm_native_cot_{lang}"
            filter_list = add_regex_pattern(strings["REGEX"])
            # zh/ja get an explicit empty target_delimiter; presumably because
            # these scripts do not separate words with spaces.
            delimiter = "" if lang in ("zh", "ja") else None
        else:  # "en-cot"
            answer_prompt = LANGUAGES["en"]["ANSWER"]
            task_name = f"mgsm_en_cot_{lang}"
            # No per-task filter_list is added here; presumably the cot_yaml
            # template used for en-cot supplies the English filters -- TODO
            # confirm.

        file_name = f"{task_name}.yaml"
        # Number of leading characters dropped from the dataset's `answer`
        # field: the native ANSWER prefix plus one separator character. Always
        # the native prefix, whatever the mode; presumably the dataset's worked
        # answers start with it -- confirm against the dataset.
        answer_to_skip = len(strings["ANSWER"]) + 1

        config = {
            "include": yaml_template,
            "dataset_name": lang,
            "task": task_name,
            "doc_to_text": f"""{{% if answer is not none %}}"""
            f"""{{{{question+"\\n{answer_prompt}"}}}}"""
            f"""{{% else %}}"""
            f"""{{{{"{question} "+question+"\\n{answer_prompt}"}}}}"""
            f"""{{% endif %}}""",
            "doc_to_target": f"""{{% if answer is not none %}}"""
            f"""{{{{answer[{answer_to_skip}:]}}}}"""
            f"""{{% else %}}"""
            f"""{{{{answer_number|string}}}}"""
            f"""{{% endif %}}""",
            **filter_list,
            "generation_kwargs": {
                "until": [question, "</s>", "<|im_end|>"],
                "do_sample": False,
            },
        }
        # The only non-None delimiter is "" (zh/ja), which is falsy: the old
        # truthiness test meant target_delimiter was never written. Compare
        # against None instead.
        if delimiter is not None:
            config["target_delimiter"] = delimiter

        try:
            # "x" mode fails with FileExistsError instead of clobbering.
            with open(
                os.path.join(output_dir, file_name),
                "w" if overwrite else "x",
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(config, f, allow_unicode=True, width=float("inf"))
        except FileExistsError:
            err.append(file_name)

    if len(err) > 0:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this generator script."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    cli.add_argument(
        "--mode",
        choices=["direct", "native-cot", "en-cot"],
        default="native-cot",
        help="Mode of chain-of-thought",
    )
    return cli


def main() -> None:
    """Read command-line options and write the per-language yaml files."""
    options = _build_arg_parser().parse_args()
    gen_lang_yamls(
        output_dir=options.output_dir,
        overwrite=options.overwrite,
        mode=options.mode,
    )


if __name__ == "__main__":
    main()