utils.py 3.79 KB
Newer Older
haileyschoelkopf's avatar
haileyschoelkopf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import os
from typing import Dict, List

import sacrebleu
import yaml

try:
    import pycountry
except ModuleNotFoundError as err:
    # pycountry maps ISO language codes to language names for the prompt
    # templates; fail early with installation instructions if it is missing.
    raise ModuleNotFoundError(
        "`pycountry` is required for generating translation task prompt templates. "
        "please install pycountry via pip install lm-eval[multilingual] "
        "or pip install -e .[multilingual]",
    ) from err


# Different translation benchmarks included in the library. Mostly WMT.
# These correspond to dataset names (subsets) on HuggingFace for each dataset.
# A yaml file is generated by this script for each language pair.

# Benchmarks whose generated tasks are additionally tagged with the
# "gpt3_translation_benchmarks" group by gen_lang_yamls.
gpt3_translation_benchmarks = {
    "wmt14": ["fr-en"],  # French (the en-fr direction is currently disabled)
    "wmt16": [
        "ro-en",  # Romanian
        "de-en",  # German
    ],  # the en-ro and en-de directions are currently disabled
}

# Every benchmark -> language pairs for which a yaml task file is generated.
# Currently 5 language pairs in total.
LANGUAGES = {
    **gpt3_translation_benchmarks,
    # Disabled: "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
    "iwslt2017": ["en-ar", "ar-en"],  # Arabic
}


def code_to_language(code):
    """
    Return the language name pycountry associates with a language code.

    :param code: A 2-letter (looked up as ``alpha_2``) or 3-letter (looked up
        as ``alpha_3``) language code, e.g. ``"fr"``.
    :return: The ``name`` attribute of the matching pycountry language record.
    :raises ValueError: If ``code`` is not 2 or 3 characters long, or if
        pycountry knows no language with that code.
    """
    if len(code) not in (2, 3):
        raise ValueError(
            f"Language code must be 2 or 3 characters long, got {code!r}"
        )
    try:
        # key is alpha_2 or alpha_3 depending on the code length
        language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
    except KeyError:
        # Older pycountry releases raise KeyError for unknown codes; newer
        # ones return None. Normalise both to the same error below.
        language_tuple = None
    if language_tuple is None:
        raise ValueError(f"Unknown language code: {code!r}")
    return language_tuple.name


def _task_config(lang: str, lang_pair: str) -> dict:
    """
    Build the task config that is dumped to yaml for one language pair.

    :param lang: Benchmark / HuggingFace dataset name, e.g. ``"wmt16"``.
    :param lang_pair: ``"<src>-<tgt>"`` pair of language codes, e.g. ``"ro-en"``.
    :return: Dict of task config keys, ready to pass to ``yaml.dump``.
    """
    src_lang, _, tgt_lang = lang_pair.partition("-")
    source, target = code_to_language(src_lang), code_to_language(tgt_lang)

    groups = ["greedy_until", "translation", lang]
    if lang in gpt3_translation_benchmarks:
        groups.append("gpt3_translation_benchmarks")

    return {
        "include": "wmt_common_yaml",
        "group": groups,
        "dataset_path": lang,
        # iwslt2017 subset names carry the dataset name as a prefix.
        "dataset_name": lang_pair
        if lang != "iwslt2017"
        else "iwslt2017-" + lang_pair,
        "task": f"{lang}-{lang_pair}",
        # The {{...}} placeholders are written literally into the yaml;
        # presumably they are rendered per document by the harness.
        "doc_to_text": (
            f"{source} phrase: "
            + "{{translation[" + f'"{src_lang}"' + "]}}\n"
            + f"{target} phrase:"
        ),
        "doc_to_target": " {{translation[" + f'"{tgt_lang}"' + "]}}",
    }


def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a yaml file for each language pair in ``LANGUAGES``.

    Files are named ``<benchmark>_<pair>.yaml`` and written into
    ``output_dir``. When ``overwrite`` is false, files that already exist are
    left untouched, and all of them are reported together after every other
    file has been written.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :raises FileExistsError: If ``overwrite`` is false and one or more target
        files already exist (the message lists all of them).
    """
    # Mode "x" makes open() raise FileExistsError instead of truncating.
    mode = "w" if overwrite else "x"
    err = []
    for lang, lang_pairs in LANGUAGES.items():
        for lang_pair in lang_pairs:
            file_name = f"{lang}_{lang_pair}.yaml"
            config = _task_config(lang, lang_pair)
            # os.path.join avoids writing to "/" when output_dir is "".
            path = os.path.join(output_dir, file_name)
            try:
                f = open(path, mode, encoding="utf8")
            except FileExistsError:
                # Keep going so every conflicting file is reported at once.
                err.append(file_name)
                continue
            with f:
                f.write("# Generated by utils.py\n")
                yaml.dump(config, f)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Command-line entry point: read options, then write the per-pair yaml files."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    options = cli.parse_args()
    # argparse stores --output-dir as `output_dir`, so the namespace matches
    # gen_lang_yamls' keyword parameters (output_dir, overwrite) exactly.
    gen_lang_yamls(**vars(options))


if __name__ == "__main__":
    main()