import yaml
import argparse


# Per-language prompt fragments, keyed by language code.
#   QUESTION - prefix placed in front of the question text
#   ANSWER   - prefix that introduces a step-by-step answer
#   DIRECT   - prefix that introduces a direct answer
#   REGEX    - pattern whose single capture group is the final numeric answer
# Non-ASCII values are written as \u escapes; the commented-out line above
# each one shows the same text in its native script.
LANGUAGES = {
    "bn": {  # Bengali
        # "QUESTION": "প্রশ্ন:",
        "QUESTION": "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:",
        # "ANSWER": "ধাপে ধাপে উত্তর:",
        "ANSWER": "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:",
        "DIRECT": "Answer:",
        # NOTE(review): this pattern is in English, unlike most other
        # languages here — confirm that is intended for Bengali.
        "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
    },
    "de": {  # German
        "QUESTION": "Frage:",
        # "ANSWER": "Schritt-für-Schritt-Antwort:",
        "ANSWER": "Schritt-f\u00fcr-Schritt-Antwort:",
        "DIRECT": "Antwort:",
        "REGEX": "Die Antwort lautet (\\-?[0-9\\.\\,]+)",
    },
    "en": {  # English
        "QUESTION": "Question:",
        "ANSWER": "Step-by-Step Answer:",
        "DIRECT": "Answer:",
        "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
    },
    "es": {  # Spanish
        "QUESTION": "Pregunta:",
        "ANSWER": "Respuesta paso a paso:",
        "DIRECT": "Respuesta:",
        "REGEX": "La respuesta es (\\-?[0-9\\.\\,]+)",
    },
    "fr": {  # French
        "QUESTION": "Question :",
        # "ANSWER": "Réponse étape par étape :"
        "ANSWER": "R\u00e9ponse \u00e9tape par \u00e9tape :",
        # "DIRECT": "Réponse :",
        "DIRECT": "R\u00e9ponse :",
        # "REGEX": "La réponse est (\\-?[0-9\\.\\,]+)",
        "REGEX": "La r\u00e9ponse est (\\-?[0-9\\.\\,]+)",
    },
    "ru": {  # Russian
        # "QUESTION": "Задача:",
        "QUESTION": "\u0417\u0430\u0434\u0430\u0447\u0430:",
        # NOTE(review): there is no space between the two words of this
        # prefix — confirm it matches the prefix used in the dataset.
        # "ANSWER": "Пошаговоерешение:",
        "ANSWER": "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:",
        "DIRECT": "Answer:",
        # "REGEX": "Ответ — (\\-?[0-9\\.\\,]+)",
        "REGEX": "\u041e\u0442\u0432\u0435\u0442 \u2014 (\\-?[0-9\\.\\,]+)",
    },
    "sw": {  # Swahili
        "QUESTION": "Swali:",
        "ANSWER": "Jibu la Hatua kwa Hatua:",
        "DIRECT": "Answer:",
        "REGEX": "Jibu ni (\\-?[0-9\\.\\,]+)",
    },
    "te": {  # Telugu
        # "QUESTION": "ప్రశ్న:",
        "QUESTION": "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:",
        # "ANSWER": "దశలవారీగా సమాధానం:",
        "ANSWER": "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:",
        "DIRECT": "Answer:",
        # "REGEX": "సమాధానం (\\-?[0-9\\.\\,]+)",
        "REGEX": "\u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02 (\\-?[0-9\\.\\,]+)",
    },
    "th": {  # Thai
        # "QUESTION": "โจทย์:",
        "QUESTION": "\u0e42\u0e08\u0e17\u0e22\u0e4c:",
        # "ANSWER": "คำตอบทีละขั้นตอน:",
        "ANSWER": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:",
        "DIRECT": "Answer:",
        # "REGEX": "คำตอบคือ (\\-?[0-9\\.\\,]+)",
        "REGEX": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e04\u0e37\u0e2d (\\-?[0-9\\.\\,]+)",
    },
    "ja": {  # Japanese
        # "QUESTION": "問題:",
        "QUESTION": "\u554f\u984c:",
        # "ANSWER": "ステップごとの答え:",
        "ANSWER": "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:",
        "DIRECT": "Answer:",
        # "REGEX": "答えは(\\-?[0-9\\.\\,]+)です。",
        "REGEX": "\u7b54\u3048\u306f(\\-?[0-9\\.\\,]+)\u3067\u3059\u3002",
    },
    "zh": {  # Chinese
        # "QUESTION": "问题:",
        "QUESTION": "\u95ee\u9898:",
        # "ANSWER": "逐步解答:",
        "ANSWER": "\u9010\u6b65\u89e3\u7b54:",
        "DIRECT": "Answer:",
        # "REGEX": "答案是 (\\-?[0-9\\.\\,]+)。",
        "REGEX": "\u7b54\u6848\u662f (\\-?[0-9\\.\\,]+)\u3002",
    },
}


def add_regex_pattern(regex_pattern):
    """
    Build the ``filter_list`` config fragment for a regex answer filter.

    :param regex_pattern: The regex pattern string to put in the filter, or
        None when no filter is wanted.
    :return: An empty dict when ``regex_pattern`` is None. Otherwise a dict
        with a single ``filter_list`` key holding one filter named
        ``get-answer`` that applies ``regex`` with the given pattern and then
        ``take_first``. The dict is meant to be merged into a larger config
        with ``**``.
    """
    if regex_pattern is None:
        return {}
    return {
        "filter_list": [
            {
                "name": "get-answer",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": regex_pattern,
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
        ],
    }


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt style to generate: "direct", "native-cot" or "en-cot".
    :raises ValueError: If ``mode`` is not one of the supported values.
    :raises FileExistsError: If ``overwrite`` is false and one or more target
        files already exist. Raised only after every language has been
        attempted; the message lists all files that were skipped.
    """
    supported_modes = ("direct", "native-cot", "en-cot")
    if mode not in supported_modes:
        # Fail early with a clear message instead of an unbound-variable
        # error further down.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of: {', '.join(supported_modes)}"
        )

    err = []
    for lang, prompts in LANGUAGES.items():
        question_prefix = prompts["QUESTION"]

        yaml_template = "cot_yaml"
        filter_list = {}
        if mode == "direct":
            answer_prefix = prompts["DIRECT"]
            yaml_template = "direct_yaml"
        elif mode == "native-cot":
            answer_prefix = prompts["ANSWER"]
            filter_list = add_regex_pattern(prompts["REGEX"])
        else:  # "en-cot"
            # NOTE(review): no filter_list is emitted in this mode;
            # presumably the included cot_yaml template already carries the
            # English answer filter — confirm.
            answer_prefix = LANGUAGES["en"]["ANSWER"]

        task_name = f"mgsm_{lang}_{mode}"
        file_name = f"{task_name}.yaml"

        # Jinja templates. When the document has a worked answer, the question
        # is used as-is; otherwise the question prefix is prepended.
        doc_to_text = (
            "{% if answer is not none %}"
            f'{{{{question+"\\n{answer_prefix}"}}}}'
            "{% else %}"
            f'{{{{"{question_prefix} "+question+"\\n{answer_prefix}"}}}}'
            "{% endif %}"
        )
        # Slice (not index) so the target is the whole answer text after the
        # answer prefix and the one character that follows it.
        doc_to_target = (
            "{% if answer is not none %}"
            f"{{{{answer[{len(answer_prefix) + 1}:]}}}}"
            "{% else %}"
            "{{answer_number|string}}"
            "{% endif %}"
        )

        config = {
            "include": yaml_template,
            "dataset_name": lang,
            # Use the per-mode task name so files of different modes do not
            # all declare the "direct" task.
            "task": task_name,
            "doc_to_text": doc_to_text,
            "doc_to_target": doc_to_target,
            **filter_list,
        }

        try:
            # Mode "x" makes open() raise FileExistsError rather than clobber
            # an existing file when overwrite is off.
            with open(
                f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    config,
                    f,
                    allow_unicode=True,
                    width=float("inf"),
                )
        except FileExistsError:
            # Keep going so every conflicting file is reported at once.
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Parse CLI args and generate language-specific yaml files.

    Flags: ``--overwrite`` (replace existing files), ``--output-dir``
    (target directory, default ".") and ``--mode`` (one of "direct",
    "native-cot", "en-cot"; default "native-cot"). The parsed values are
    passed straight to ``gen_lang_yamls``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        default=False,
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--output-dir", default=".", help="Directory to write yaml files to"
    )
    parser.add_argument(
        "--mode",
        default="native-cot",
        choices=["direct", "native-cot", "en-cot"],
        help="Mode of chain-of-thought",
    )
    args = parser.parse_args()

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()