utils.py 3.02 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
from typing import Dict, List

import yaml


# Languages covered by PAWS-X.
# Each key is used verbatim as the dataset name (subset) on HuggingFace and
# as the suffix of the generated task name / yaml file (paws_<key>.yaml).
# A yaml file is generated by this script for each language.
#
# Per-language words used to build the answer-choice template:
#   QUESTION_WORD -- tag question appended after sentence1 (e.g. "..., right?")
#   YES           -- affirmative answer word (first choice)
#   NO            -- negative answer word (second choice)

LANGUAGES = {
    "de": {  # German
        "QUESTION_WORD": "richtig",
        "YES": "Ja",
        "NO": "Nein",
    },
    "en": {  # English
        "QUESTION_WORD": "right",
        "YES": "Yes",
        "NO": "No",
    },
    "es": {  # Spanish
        "QUESTION_WORD": "verdad",
        "YES": "Sí",
        "NO": "No",
    },
    "fr": {  # French
        "QUESTION_WORD": "n'est-ce pas",
        "YES": "Oui",
        # French negative answer is "Non" (previously the non-French "No").
        "NO": "Non",
    },
    "ja": {  # Japanese
        "QUESTION_WORD": "ですね",
        "YES": "はい",
        "NO": "いいえ",
    },
    "ko": {  # Korean
        "QUESTION_WORD": "맞죠",
        "YES": "예",
        "NO": "아니요",
    },
    "zh": {  # Chinese
        "QUESTION_WORD": "对吧",
        "YES": "是",
        "NO": "不是",
    },
}


def _doc_to_choice(question_word: str, yes: str, no: str) -> str:
    """
    Build the Jinja ``doc_to_choice`` template for one language.

    The template is a Jinja list literal with two entries, each joining
    sentence1 and sentence2 via ``", <question_word>? <answer>, "``; the
    first entry uses ``yes`` and the second uses ``no``.

    :param question_word: Tag question inserted after sentence1.
    :param yes: Affirmative answer word (first choice).
    :param no: Negative answer word (second choice).
    :return: A string of the form ``{{[<yes option>, <no option>]}}``.
    """
    options = ", ".join(
        f'sentence1+", {question_word}? {answer}, "+sentence2' for answer in (yes, no)
    )
    # Plain concatenation (not an f-string) so the Jinja braces stay literal.
    return "{{[" + options + "]}}"


def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a yaml file for each language.

    Writes ``paws_<lang>.yaml`` into ``output_dir`` for every key of
    ``LANGUAGES``. Each file starts with a "# Generated by utils.py" header
    followed by the dumped task config.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :raises FileExistsError: If ``overwrite`` is False and one or more target
        files already exist. The existing files are left untouched, the
        remaining languages are still written, and all skipped file names are
        reported together in a single error at the end.
    """
    err = []
    # "x" (exclusive creation) raises FileExistsError instead of truncating.
    mode = "w" if overwrite else "x"
    for lang, words in LANGUAGES.items():
        file_name = f"paws_{lang}.yaml"
        config = {
            "include": "pawsx_template_yaml",
            "dataset_name": lang,
            "task": f"paws_{lang}",
            "doc_to_text": "",
            "doc_to_choice": _doc_to_choice(
                words["QUESTION_WORD"], words["YES"], words["NO"]
            ),
        }
        try:
            with open(f"{output_dir}/{file_name}", mode, encoding="utf8") as f:
                f.write("# Generated by utils.py\n")
                # allow_unicode emits non-ASCII answer words as-is, not escaped.
                yaml.dump(config, f, allow_unicode=True)
        except FileExistsError:
            err.append(file_name)

    if err:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(err)}"
        )


def main() -> None:
    """Command-line entry point: read the options, then emit one yaml per language."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    cli.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    opts = cli.parse_args()
    destination, clobber = opts.output_dir, opts.overwrite
    gen_lang_yamls(output_dir=destination, overwrite=clobber)


# Generate the yaml files only when run as a script, not when imported.
if __name__ == "__main__":
    main()