utils.py 4.06 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
import yaml
lintangsutawika's avatar
lintangsutawika committed
2
import argparse
lintangsutawika's avatar
lintangsutawika committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74


# Prompt prefixes per language code. Each entry maps to a dict with:
#   "QUESTION" - the label placed before a question,
#   "ANSWER"   - the label that introduces a step-by-step answer.
# Non-ASCII text is kept as \u escapes so the source file stays pure ASCII.
# NOTE(review): the Russian ANSWER prefix has no space between its two words.
# gen_lang_yamls uses len(ANSWER) to strip this prefix from answers, so it
# presumably has to mirror the dataset text exactly - confirm before "fixing".
LANGUAGES = {
    code: {"QUESTION": question_prefix, "ANSWER": answer_prefix}
    for code, question_prefix, answer_prefix in (
        # Bengali
        (
            "bn",
            "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:",
            "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:",
        ),
        # German
        ("de", "Frage:", "Schritt-f\u00fcr-Schritt-Antwort:"),
        # English
        ("en", "Question:", "Step-by-Step Answer:"),
        # Spanish
        ("es", "Pregunta:", "Respuesta paso a paso:"),
        # French
        ("fr", "Question :", "R\u00e9ponse \u00e9tape par \u00e9tape :"),
        # Russian
        (
            "ru",
            "\u0417\u0430\u0434\u0430\u0447\u0430:",
            "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:",
        ),
        # Swahili
        ("sw", "Swali:", "Jibu la Hatua kwa Hatua:"),
        # Telugu
        (
            "te",
            "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:",
            "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:",
        ),
        # Thai
        (
            "th",
            "\u0e42\u0e08\u0e17\u0e22\u0e4c:",
            "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:",
        ),
        # Japanese
        (
            "ja",
            "\u554f\u984c:",
            "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:",
        ),
        # Chinese
        ("zh", "\u95ee\u9898:", "\u9010\u6b65\u89e3\u7b54:"),
    )
}


def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a yaml file for each language.

    One ``mgsm_<lang>.yaml`` file is written per entry in ``LANGUAGES``. Each
    file starts with a "Generated by" comment line followed by a YAML mapping
    holding the ``include``, ``dataset_name``, ``task``, ``doc_to_text`` and
    ``doc_to_target`` keys; the last two are Jinja template strings built from
    the language's QUESTION / ANSWER prefixes.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :raises FileExistsError: If ``overwrite`` is false and at least one target
        file already exists. Every language is still attempted first, so files
        that did not exist yet are written before the error is raised.
    """
    skipped = []
    for lang, prefixes in LANGUAGES.items():
        question_prefix = prefixes["QUESTION"]
        answer_prefix = prefixes["ANSWER"]
        file_name = f"mgsm_{lang}.yaml"

        # Prompt template: when the doc has an answer, the question is used
        # as-is; otherwise the QUESTION prefix is prepended. Both branches end
        # with the ANSWER prefix on a new line. The "\\n" stays a literal
        # backslash-n in the YAML so that Jinja, not Python, expands it.
        doc_to_text = (
            "{% if answer is not none %}"
            f'{{{{question+"\\n{answer_prefix}"}}}}'
            "{% else %}"
            f'{{{{"{question_prefix} "+question+"\\n{answer_prefix}"}}}}'
            "{% endif %}"
        )

        # Target template: drop the ANSWER prefix plus the one character that
        # follows it (presumably a separating space - confirm against the
        # dataset) and keep the remainder. The trailing ":" makes this a slice;
        # without it the template would emit only a single character.
        doc_to_target = (
            "{% if answer is not none %}"
            f"{{{{answer[{len(answer_prefix)}+1:]}}}}"
            "{% else %}"
            "{{answer_number|string}}"
            "{% endif %}"
        )

        config = {
            "include": "common_template_yaml",
            "dataset_name": lang,
            "task": f"mgsm_{lang}",
            "doc_to_text": doc_to_text,
            "doc_to_target": doc_to_target,
        }

        try:
            # Mode "x" makes open() fail with FileExistsError rather than
            # clobbering an existing file when overwrite is off.
            with open(
                f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(config, f, allow_unicode=True)
        except FileExistsError:
            # Keep going so every conflicting file is reported in one error.
            skipped.append(file_name)

    if skipped:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(skipped)}"
        )


def main() -> None:
    """Read ``--overwrite`` / ``--output-dir`` from the command line and generate the yaml files."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    options = cli.parse_args()

    # argparse exposes "--output-dir" as the attribute "output_dir".
    gen_lang_yamls(output_dir=options.output_dir, overwrite=options.overwrite)


# Run the generator only when this file is executed as a script, so that
# importing the module (e.g. to reuse LANGUAGES) has no side effects.
if __name__ == "__main__":
    main()