# utils.py - generates one afrimgsm task YAML file per language.
import argparse

import yaml
# Language codes for which a task YAML file is generated.
languages = [
    "eng", "amh", "ibo", "fra", "sna", "lin", "wol", "ewe", "lug",
    "xho", "kin", "twi", "zul", "orm", "yor", "hau", "sot", "swa",
]

# Prompt fragments per language code. Every language is given the same
# English wording; each code gets its own dict so that entries can be
# customised independently without affecting the others.
LANGUAGES = {}

for lang in languages:
    LANGUAGES[lang] = {
        "QUESTION": "Question:",
        "ANSWER": "Step-by-Step Answer:",
        "DIRECT": "Answer:",
        # Captures the (optionally negative) number following "The answer is ".
        "REGEX": r"The answer is (\-?[0-9\.\,]+)",
    }


def add_regex_pattern(regex_pattern):
    """
    Build the ``filter_list`` fragment of a task config.

    :param regex_pattern: Pattern used by the "strict-match" filter, or None.
    :return: An empty dict when ``regex_pattern`` is None; otherwise a dict
        with a single ``"filter_list"`` key holding two entries,
        "strict-match" and "flexible-extract". Each entry runs a "regex"
        step followed by a "take_first" step.
    """
    if regex_pattern is None:
        return {}

    # Regex step driven by the caller-supplied pattern.
    strict_step = {
        "function": "regex",
        "regex_pattern": f"""{regex_pattern}""",
    }
    # Regex step with a fixed number-like pattern and group_select of -1.
    flexible_step = {
        "function": "regex",
        "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
        "group_select": -1,
    }
    named_steps = (
        ("strict-match", strict_step),
        ("flexible-extract", flexible_step),
    )

    return {
        "filter_list": [
            # A fresh "take_first" dict per entry keeps the entries independent.
            {"name": label, "filter": [regex_step, {"function": "take_first"}]}
            for label, regex_step in named_steps
        ],
    }


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompting mode; one of "direct", "native-cot" or "en-cot".
    :raises ValueError: If ``mode`` is not one of the supported modes.
    :raises FileExistsError: If ``overwrite`` is false and at least one target
        file already exists. Files that did not exist are still written
        before the error is raised.
    """
    supported_modes = ("direct", "native-cot", "en-cot")
    if mode not in supported_modes:
        # Fail early with a clear message instead of an unbound-variable
        # error further down.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of: "
            f"{', '.join(supported_modes)}"
        )

    err = []
    for lang in LANGUAGES:
        QUESTION = LANGUAGES[lang]["QUESTION"]

        yaml_template = "cot_yaml"
        filter_list = {}
        if mode == "direct":
            ANSWER = LANGUAGES[lang]["DIRECT"]
            task_name = f"afrimgsm_direct_{lang}"
            yaml_template = "direct_yaml"
        elif mode == "native-cot":
            ANSWER = LANGUAGES[lang]["ANSWER"]
            task_name = f"afrimgsm_native_cot_{lang}"
            filter_list = add_regex_pattern(LANGUAGES[lang]["REGEX"])
        else:  # mode == "en-cot"
            # LANGUAGES is keyed by three-letter codes, so English is "eng";
            # there is no "en" entry.
            ANSWER = LANGUAGES["eng"]["ANSWER"]
            task_name = f"afrimgsm_en_cot_{lang}"

        file_name = f"{task_name}.yaml"
        # Number of leading characters that doc_to_target slices off `answer`:
        # the length of this language's "ANSWER" prompt plus one.
        ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1

        # Doubled braces are literal braces in the emitted template text.
        doc_to_text = (
            f"""{{% if answer is not none %}}"""
            f"""{{{{question+"\\n{ANSWER}"}}}}"""
            f"""{{% else %}}"""
            f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
            f"""{{% endif %}}"""
        )
        doc_to_target = (
            f"""{{% if answer is not none %}}"""
            f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
            f"""{{% else %}}"""
            f"""{{{{answer_number|string}}}}"""
            f"""{{% endif %}}"""
        )

        try:
            # Mode "x" makes open() raise FileExistsError rather than clobber
            # an existing file when overwriting was not requested.
            with open(
                f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    {
                        "include": yaml_template,
                        "dataset_name": lang,
                        "task": f"{task_name}",
                        "doc_to_text": doc_to_text,
                        "doc_to_target": doc_to_target,
                        **filter_list,
                        "generation_kwargs": {
                            "until": [QUESTION, "</s>", "<|im_end|>"],
                            "do_sample": False,
                        },
                    },
                    f,
                    allow_unicode=True,
                    width=float("inf"),
                )
        except FileExistsError:
            # Collect every clash so they can all be reported together.
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main(argv=None) -> None:
    """
    Parse CLI args and generate language-specific yaml files.

    :param argv: Optional list of argument strings to parse. When None (the
        default), argparse reads the process command line as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        default=False,
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--output-dir", default=".", help="Directory to write yaml files to"
    )
    parser.add_argument(
        "--mode",
        default="native-cot",
        choices=["direct", "native-cot", "en-cot"],
        help="Mode of chain-of-thought",
    )
    args = parser.parse_args(argv)

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()