gen_utils.py 3.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import os

import yaml


class FunctionTag:
    """Minimal holder that stores a single value in its ``value`` attribute.

    NOTE(review): nothing in the visible part of this file uses the class;
    presumably it tags a value for custom YAML serialization elsewhere --
    confirm before relying on or removing it.
    """

    def __init__(self, value) -> None:
        # Keep the wrapped value as-is; no validation or conversion is done.
        self.value = value


def prompt_func(mode, lang):
    """
    Build the prompt template for the requested prompt variant.

    :param mode: Prompt key, one of ``"prompt_1"`` through ``"prompt_5"``.
    :param lang: Language name interpolated into prompts 3, 4 and 5.
    :return: The prompt string; it contains a literal ``{{text}}`` placeholder.
    :raises KeyError: If ``mode`` is not one of the known prompt keys.
    """
    # Every variant ends the same way. ``{{text}}`` is emitted verbatim
    # (double braces included) -- presumably filled in later by whatever
    # consumes the generated yaml.
    closing = "{{text}}. Return output sentence only"
    # Opening sentence shared by the two "linguist persona" variants.
    persona = f"You are a linguist specializing in diacritical marks for {lang}. "

    templates = {}
    templates["prompt_1"] = (
        "Please restore the missing diacritics in the following sentence: " + closing
    )
    templates["prompt_2"] = (
        "Given a sentence without diacritics, add the appropriate diacritics "
        "to make it grammatically and semantically correct. \nSentence: " + closing
    )
    templates["prompt_3"] = (
        f"This text is in {lang}. Restore all diacritical marks to their "
        "proper places in the following sentence: " + closing
    )
    templates["prompt_4"] = (
        persona + f"Add the appropriate diacritics to this {lang} sentence: " + closing
    )
    # The double space after "sentences" is part of the existing prompt text
    # and is intentionally left untouched.
    templates["prompt_5"] = (
        persona
        + f"Diacritics are essential for proper pronunciation and meaning in {lang}. "
        + f"You are tasked with converting {lang} sentences  "
        + "without diacritics into their correctly accented forms. "
        + "Here's the input: "
        + closing
    )
    return templates[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate a yaml file for each language.

    One ``afridiacritics_<lang>.yaml`` file per language is written into
    ``<output_dir>/<mode>/``; that directory is created when it is missing.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt key handed to ``prompt_func``; also used as the name
        of the sub-directory and as the suffix of each task name.
    :raises FileExistsError: If ``overwrite`` is false and at least one target
        file already exists (files that do not exist yet are still written
        before the error is raised), or if the target directory path exists
        but is not a directory.
    """
    # NOTE(review): ISO 639-3 "bbj" is Ghomálá'; "Gbomala" looks like a
    # misspelling. It is left unchanged because the name is interpolated into
    # the generated prompts -- confirm before correcting.
    languages = {
        "fon": "Fon",
        "bbj": "Gbomala",
        "ibo": "Igbo",
        "wol": "Wolof",
        "yor": "Yoruba",
    }

    # The target directory is the same for every language, so create it once.
    # Doing this outside the per-file ``try`` also keeps a genuine directory
    # problem from being reported as "file already exists".
    target_dir = f"{output_dir}/{mode}"
    os.makedirs(target_dir, exist_ok=True)

    # "x" makes open() fail with FileExistsError instead of truncating.
    open_mode = "w" if overwrite else "x"

    skipped = []
    for lang, lang_name in languages.items():
        file_name = f"afridiacritics_{lang}.yaml"
        yaml_details = {
            "include": "afridiacritics_yaml",
            "task": f"afridiacritics_{lang}_{mode}",
            "dataset_name": lang,
            "doc_to_text": prompt_func(mode, lang_name),
        }
        try:
            with open(f"{target_dir}/{file_name}", open_mode, encoding="utf8") as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    yaml_details,
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            # Remember the file and carry on so the remaining languages are
            # still generated; all collisions are reported together below.
            skipped.append(file_name)

    if skipped:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(skipped)}"
        )


def main(argv=None) -> None:
    """
    Parse CLI args and generate language-specific yaml files.

    :param argv: Optional list of argument strings to parse instead of the
        process command line; ``None`` (the default) makes argparse read
        ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    # Overwriting is on by default, so ``--overwrite`` on its own changes
    # nothing; it is kept so existing invocations keep working. The paired
    # ``--no-overwrite`` flag is what makes the behaviour selectable.
    parser.add_argument(
        "--overwrite",
        dest="overwrite",
        default=True,
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Fail instead of overwriting files that already exist",
    )
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_1",
        choices=["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"],
        help="Prompt number",
    )
    args = parser.parse_args(argv)

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


# Run the generator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()