# utils.py: generates one afrimgsm task yaml file per language.
import argparse
import yaml

# Language codes iterated by gen_lang_yamls; one yaml file is written per entry.
languages = [
    "eng", "amh", "ibo", "fra", "sna", "lin",
    "wol", "ewe", "lug", "xho", "kin", "twi",
    "zul", "orm", "yor", "hau", "sot", "swa",
]
# Prompt labels and an answer-matching pattern.
# NOTE(review): nothing in the visible part of this module reads `configs`;
# presumably it is consumed elsewhere -- confirm before changing any value.
configs = {
    "QUESTION": "Question:",
    "ANSWER": "Step-by-Step Answer:",
    "DIRECT": "Answer:",
    # Captures an optionally negative number made of digits, dots and commas.
    "REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    One ``<task_name>.yaml`` file is written to ``output_dir`` for every entry
    in ``languages``. Each file includes the shared ``afrimgsm_common_yaml``
    template and sets ``dataset_name`` and ``task`` for that language.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompting mode; one of "direct", "native-cot" or "en-cot".
    :raises ValueError: If ``mode`` is not one of the supported modes.
    :raises FileExistsError: If ``overwrite`` is false and at least one target
        file already exists. Files that did not exist are still created; the
        error lists the ones that were skipped.
    """
    # Task-name prefix for every supported mode. All modes currently share the
    # same template file.
    task_prefixes = {
        "direct": "afrimgsm_direct",
        "native-cot": "afrimgsm_native_cot",
        "en-cot": "afrimgsm_en_cot",
    }
    if mode not in task_prefixes:
        # Fail early with a clear message instead of hitting an unbound
        # `task_name` inside the loop.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of:"
            f" {', '.join(task_prefixes)}"
        )

    yaml_template = "afrimgsm_common_yaml"
    # "x" makes open() raise FileExistsError rather than clobbering a file.
    open_mode = "w" if overwrite else "x"

    err = []
    for lang in languages:
        task_name = f"{task_prefixes[mode]}_{lang}"
        file_name = f"{task_name}.yaml"
        try:
            with open(
                f"{output_dir}/{file_name}", open_mode, encoding="utf8"
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    {
                        "include": yaml_template,
                        "dataset_name": lang,
                        "task": f"{task_name}",
                    },
                    f,
                    allow_unicode=True,
                    width=float("inf"),
                )
        except FileExistsError:
            # Keep going so every conflicting file is reported in one error.
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Parse CLI args and generate language-specific yaml files.

    Recognised options:

    * ``--overwrite`` / ``--no-overwrite``: whether existing files are
      replaced. Overwriting is the default.
    * ``--output-dir``: directory the yaml files are written to.
    * ``--mode``: prompting mode passed through to ``gen_lang_yamls``.
    """
    parser = argparse.ArgumentParser()
    # `--overwrite` alone is a store_true flag whose default is already True,
    # so it can never turn overwriting off. `--no-overwrite` shares the same
    # destination and provides the missing "off" switch while keeping the
    # existing default.
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        default=True,
        action="store_true",
        help="Overwrite files if they already exist (this is the default)",
    )
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Fail instead of overwriting files that already exist",
    )
    parser.add_argument(
        "--output-dir", default="./direct", help="Directory to write yaml files to"
    )
    # NOTE(review): gen_lang_yamls only implements "direct", "native-cot" and
    # "en-cot"; "direct-native" and "translate-direct" are accepted here but
    # fail inside gen_lang_yamls.
    parser.add_argument(
        "--mode",
        default="native-cot",
        choices=["direct", "direct-native", "native-cot", "en-cot", "translate-direct"],
        help="Mode of chain-of-thought",
    )
    args = parser.parse_args()

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()