# utils.py: generates one MGSM task yaml file per language in LANGUAGES.
import argparse
import os

import yaml


# Prompt prefixes per language; row order becomes the key order of LANGUAGES.
# Columns: (language code, question prefix, step-by-step answer prefix,
#           direct-answer prefix). Non-ASCII text is kept as \u escapes.
# Only German has a native direct-answer prefix; the others use "Answer:".
_PROMPT_PREFIXES = (
    (  # Bengali
        "bn",
        "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:",
        "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:",
        "Answer:",
    ),
    (  # German
        "de",
        "Frage:",
        "Schritt-f\u00fcr-Schritt-Antwort:",
        "Antwort:",
    ),
    (  # English
        "en",
        "Question:",
        "Step-by-Step Answer:",
        "Answer:",
    ),
    (  # Spanish
        "es",
        "Pregunta:",
        "Respuesta paso a paso:",
        "Answer:",
    ),
    (  # French
        "fr",
        "Question :",
        "R\u00e9ponse \u00e9tape par \u00e9tape :",
        "Answer:",
    ),
    (  # Russian
        "ru",
        "\u0417\u0430\u0434\u0430\u0447\u0430:",
        # NOTE(review): the two Russian words run together (no space);
        # presumably kept to match the dataset text — confirm before changing.
        "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:",
        "Answer:",
    ),
    (  # Swahili
        "sw",
        "Swali:",
        "Jibu la Hatua kwa Hatua:",
        "Answer:",
    ),
    (  # Telugu
        "te",
        "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:",
        "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:",
        "Answer:",
    ),
    (  # Thai
        "th",
        "\u0e42\u0e08\u0e17\u0e22\u0e4c:",
        "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:",
        "Answer:",
    ),
    (  # Japanese
        "ja",
        "\u554f\u984c:",
        "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:",
        "Answer:",
    ),
    (  # Chinese
        "zh",
        "\u95ee\u9898:",
        "\u9010\u6b65\u89e3\u7b54:",
        "Answer:",
    ),
)

# language code -> {"QUESTION", "ANSWER", "DIRECT", "REGEX"}.
# Every language currently shares the same English answer-extraction regex.
LANGUAGES = {
    code: {
        "QUESTION": question,
        "ANSWER": step_by_step,
        "DIRECT": direct,
        "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
    }
    for code, question, step_by_step, direct in _PROMPT_PREFIXES
}

def add_regex_pattern(regex_pattern):
    """
    Build the answer-extraction filter section of a task yaml.

    :param regex_pattern: Regex used to pull the final answer out of a model
        completion, or None when no extraction filter is wanted.
    :return: An empty dict when ``regex_pattern`` is None. Otherwise, a dict
        with a single ``filter_list`` entry named ``get-answer`` whose
        ``filter`` steps are a ``regex`` step followed by a ``take_first`` step.
        The result is meant to be merged into a task config with ``**``.
    """
    if regex_pattern is None:
        return {}
    return {
        "filter_list": [
            {
                "name": "get-answer",
                # The steps belong to this named entry. A "filter" key placed
                # next to "filter_list" would leave "get-answer" with no steps.
                "filter": [
                    {"function": "regex", "regex_pattern": regex_pattern},
                    {"function": "take_first"},
                ],
            },
        ],
    }


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompting mode. "direct" uses the language's direct-answer
        prefix and no extraction filter. "native-cot" uses the language's own
        step-by-step prefix and regex. "en-cot" uses the English step-by-step
        prefix and regex for every language.
    :raises ValueError: If ``mode`` is not one of the supported modes.
    :raises FileExistsError: If ``overwrite`` is False and some target files
        already exist. The remaining files are still written; the error
        message lists the skipped ones.
    """
    if mode not in ("direct", "native-cot", "en-cot"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'direct', 'native-cot' or 'en-cot'."
        )

    err = []
    for lang, prompts in LANGUAGES.items():
        question_prefix = prompts["QUESTION"]

        yaml_template = "cot_yaml"
        if mode == "direct":
            answer_prefix = prompts["DIRECT"]
            regex = None
            task_name = f"mgsm_{lang}_direct"
            yaml_template = "direct_yaml"
        elif mode == "native-cot":
            answer_prefix = prompts["ANSWER"]
            regex = prompts["REGEX"]
            task_name = f"mgsm_{lang}_native-cot"
        else:  # mode == "en-cot" (validated above)
            answer_prefix = LANGUAGES["en"]["ANSWER"]
            regex = LANGUAGES["en"]["REGEX"]
            task_name = f"mgsm_{lang}_en-cot"

        file_name = f"{task_name}.yaml"
        filter_list = add_regex_pattern(regex)

        # Characters to strip from the front of `answer` for the target: the
        # language's step-by-step prefix plus the following space.
        # NOTE(review): assumes the dataset's `answer` field always starts with
        # LANGUAGES[lang]["ANSWER"], independent of the prompting mode, so the
        # prompt label (`answer_prefix`) is not used here — confirm against the
        # dataset.
        target_offset = len(prompts["ANSWER"]) + 1

        try:
            with open(
                os.path.join(output_dir, file_name),
                "w" if overwrite else "x",  # "x" raises FileExistsError
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    {
                        "include": yaml_template,
                        "dataset_name": lang,
                        "task": task_name,
                        "doc_to_text": f"""{{% if answer is not none %}}"""
                        f"""{{{{question+"\\n{answer_prefix}"}}}}"""
                        f"""{{% else %}}"""
                        f"""{{{{"{question_prefix} "+question+"\\n{answer_prefix}"}}}}"""
                        f"""{{% endif %}}""",
                        # A slice keeps everything after the prefix; indexing
                        # would yield a single character.
                        "doc_to_target": f"""{{% if answer is not none %}}"""
                        f"""{{{{answer[{target_offset}:]}}}}"""
                        f"""{{% else %}}"""
                        f"""{{{{answer_number|string}}}}"""
                        f"""{{% endif %}}""",
                        **filter_list,
                    },
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Read the command-line flags, then write the per-language yaml files."""
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in this order so the
    # --help listing is unchanged.
    option_specs = (
        (
            "--overwrite",
            {
                "default": False,
                "action": "store_true",
                "help": "Overwrite files if they already exist",
            },
        ),
        (
            "--output-dir",
            {"default": ".", "help": "Directory to write yaml files to"},
        ),
        (
            "--mode",
            {
                "default": "native-cot",
                "choices": ["direct", "native-cot", "en-cot"],
                "help": "Mode of chain-of-thought",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()
    gen_lang_yamls(
        output_dir=opts.output_dir,
        overwrite=opts.overwrite,
        mode=opts.mode,
    )


if __name__ == "__main__":
    main()