"tests/t5/test_modeling_t5.py" did not exist on "0558c9cb9b62771d3e3955dc15d34475c2c86995"
Commit 83f95961 authored by lintangsutawika's avatar lintangsutawika
Browse files

add task variants

parent d99e6cf4
# Generated by utils.py
dataset_name: th
doc_to_target: '{% if answer is not none %}{{answer[6+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nAnswer"}}{% else %}{{"โจทย์:
"+question+"\nAnswer"}}{% endif %}'
include: direct_yaml
task: mgsm_th_direct
...@@ -4,5 +4,11 @@ doc_to_target: '{% if answer is not none %}{{answer[17+1]}}{% else %}{{answer_nu ...@@ -4,5 +4,11 @@ doc_to_target: '{% if answer is not none %}{{answer[17+1]}}{% else %}{{answer_nu
endif %}' endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nคำตอบทีละขั้นตอน:"}}{% else doc_to_text: '{% if answer is not none %}{{question+"\nคำตอบทีละขั้นตอน:"}}{% else
%}{{"โจทย์: "+question+"\nคำตอบทีละขั้นตอน:"}}{% endif %}' %}{{"โจทย์: "+question+"\nคำตอบทีละขั้นตอน:"}}{% endif %}'
include: common_template_yaml filter:
task: mgsm_th - function: regex
regex_pattern: The answer is (\-?[0-9\.\,]+)
- function: take_first
filter_list:
- name: get-answer
include: cot_yaml
task: mgsm_th_direct
# Generated by utils.py
dataset_name: zh
doc_to_target: '{% if answer is not none %}{{answer[6+1]}}{% else %}{{answer_number|string}}{%
endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nAnswer"}}{% else %}{{"问题: "+question+"\nAnswer"}}{%
endif %}'
include: direct_yaml
task: mgsm_zh_direct
...@@ -4,5 +4,11 @@ doc_to_target: '{% if answer is not none %}{{answer[5+1]}}{% else %}{{answer_num ...@@ -4,5 +4,11 @@ doc_to_target: '{% if answer is not none %}{{answer[5+1]}}{% else %}{{answer_num
endif %}' endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\n逐步解答:"}}{% else %}{{"问题: "+question+"\n逐步解答:"}}{% doc_to_text: '{% if answer is not none %}{{question+"\n逐步解答:"}}{% else %}{{"问题: "+question+"\n逐步解答:"}}{%
endif %}' endif %}'
include: common_template_yaml filter:
task: mgsm_zh - function: regex
regex_pattern: The answer is (\-?[0-9\.\,]+)
- function: take_first
filter_list:
- name: get-answer
include: cot_yaml
task: mgsm_zh_direct
...@@ -6,51 +6,94 @@ LANGUAGES = { ...@@ -6,51 +6,94 @@ LANGUAGES = {
"bn": { # Bengali "bn": { # Bengali
"QUESTION": "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:", "QUESTION": "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:",
"ANSWER": "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:", "ANSWER": "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"de": { # German "de": { # German
"QUESTION": "Frage:", "QUESTION": "Frage:",
"ANSWER": "Schritt-f\u00fcr-Schritt-Antwort:", "ANSWER": "Schritt-f\u00fcr-Schritt-Antwort:",
"DIRECT": "Antwort:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"en": { # English "en": { # English
"QUESTION": "Question:", "QUESTION": "Question:",
"ANSWER": "Step-by-Step Answer:", "ANSWER": "Step-by-Step Answer:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"es": { # Spanish "es": { # Spanish
"QUESTION": "Pregunta:", "QUESTION": "Pregunta:",
"ANSWER": "Respuesta paso a paso:", "ANSWER": "Respuesta paso a paso:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"fr": { # French "fr": { # French
"QUESTION": "Question :", "QUESTION": "Question :",
"ANSWER": "R\u00e9ponse \u00e9tape par \u00e9tape :", "ANSWER": "R\u00e9ponse \u00e9tape par \u00e9tape :",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"ru": { # Russian "ru": { # Russian
"QUESTION": "\u0417\u0430\u0434\u0430\u0447\u0430:", "QUESTION": "\u0417\u0430\u0434\u0430\u0447\u0430:",
"ANSWER": "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:", "ANSWER": "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"sw": { # Swahili "sw": { # Swahili
"QUESTION": "Swali:", "QUESTION": "Swali:",
"ANSWER": "Jibu la Hatua kwa Hatua:", "ANSWER": "Jibu la Hatua kwa Hatua:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"te": { # Telugu "te": { # Telugu
"QUESTION": "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:", "QUESTION": "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:",
"ANSWER": "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:", "ANSWER": "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"th": { # Thai "th": { # Thai
"QUESTION": "\u0e42\u0e08\u0e17\u0e22\u0e4c:", "QUESTION": "\u0e42\u0e08\u0e17\u0e22\u0e4c:",
"ANSWER": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:", "ANSWER": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"ja": { # Japanese "ja": { # Japanese
"QUESTION": "\u554f\u984c:", "QUESTION": "\u554f\u984c:",
"ANSWER": "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:", "ANSWER": "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
"zh": { # Chinese "zh": { # Chinese
"QUESTION": "\u95ee\u9898:", "QUESTION": "\u95ee\u9898:",
"ANSWER": "\u9010\u6b65\u89e3\u7b54:", "ANSWER": "\u9010\u6b65\u89e3\u7b54:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
}, },
} }
def add_regex_pattern(regex_pattern):
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None: if regex_pattern is None:
return {}
return {
"filter_list": [
{
"name": "get-answer",
},
],
"filter": [
{
"function": "regex",
"regex_pattern": regex_pattern,
},
{
"function": "take_first",
},
],
}
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
""" """
Generate a yaml file for each language. Generate a yaml file for each language.
...@@ -59,19 +102,36 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None: ...@@ -59,19 +102,36 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
""" """
err = [] err = []
for lang in LANGUAGES.keys(): for lang in LANGUAGES.keys():
file_name = f"mgsm_{lang}.yaml"
try: try:
QUESTION = LANGUAGES[lang]["QUESTION"] QUESTION = LANGUAGES[lang]["QUESTION"]
yaml_template = "cot_yaml"
if mode == "direct":
ANSWER = LANGUAGES[lang]["DIRECT"]
REGEX = None
task_name = f"mgsm_{lang}_direct"
yaml_template = "direct_yaml"
elif mode == "native-cot":
ANSWER = LANGUAGES[lang]["ANSWER"] ANSWER = LANGUAGES[lang]["ANSWER"]
REGEX = LANGUAGES[lang]["REGEX"]
task_name = f"mgsm_{lang}_native-cot"
elif model == "en-cot":
ANSWER = LANGUAGES["en"]["ANSWER"]
REGEX = LANGUAGES["en"]["REGEX"]
task_name = f"mgsm_{lang}_en-cot"
file_name = f"{file_name}.yaml"
filter_list = add_regex_pattern(REGEX)
with open( with open(
f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8" f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
) as f: ) as f:
f.write("# Generated by utils.py\n") f.write("# Generated by utils.py\n")
yaml.dump( yaml.dump(
{ {
"include": "common_template_yaml", "include": yaml_template,
"dataset_name": lang, "dataset_name": lang,
"task": f"mgsm_{lang}", "task": f"mgsm_{lang}_direct",
"doc_to_text": f"""{{% if answer is not none %}}""" \ "doc_to_text": f"""{{% if answer is not none %}}""" \
f"""{{{{question+"\\n{ANSWER}"}}}}""" \ f"""{{{{question+"\\n{ANSWER}"}}}}""" \
f"""{{% else %}}""" \ f"""{{% else %}}""" \
...@@ -82,6 +142,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None: ...@@ -82,6 +142,7 @@ def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
f"""{{% else %}}""" \ f"""{{% else %}}""" \
f"""{{{{answer_number|string}}}}""" \ f"""{{{{answer_number|string}}}}""" \
f"""{{% endif %}}""", f"""{{% endif %}}""",
**filter_list
}, },
f, f,
allow_unicode=True, allow_unicode=True,
...@@ -108,9 +169,12 @@ def main() -> None: ...@@ -108,9 +169,12 @@ def main() -> None:
parser.add_argument( parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to" "--output-dir", default=".", help="Directory to write yaml files to"
) )
parser.add_argument(
"--mode", default="native-cot", choices=["direct", "native-cot", "en-cot"], help="Mode of chain-of-thought"
)
args = parser.parse_args() args = parser.parse_args()
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite) gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment