Unverified commit 1b97e487, authored by Jess and committed by GitHub

Merge pull request #28 from JessicaOjo/africamgsm

revert xnli to multiple_choice
parents 692510cc 4583bb42
import argparse
import yaml

languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor',
             'hau', 'sot', 'swa']

languages_REGEX = {"eng": "The answer is (\\-?[0-9\\.\\,]+)",
"amh": "መልሱ (\\-?[0-9\\.\\,]+)",
"ibo": "Azịza ya bụ (\\-?[0-9\\.\\,]+)",
'fra': "La réponse est(\\-?[0-9\\.\\,]+)",
'sna': "Mhinduro kumubvunzo ndi (\\-?[0-9\\.\\,]+)",
'lin': "Eyano ezali (\\-?[0-9\\.\\,]+)",
'wol': "Tontu li (\\-?[0-9\\.\\,]+)",
'ewe': "ŋuɖoɖoae nye (\\-?[0-9\\.\\,]+)",
'lug': "Ansa eri (\\-?[0-9\\.\\,]+)",
'xho': "Impendulo ngu (\\-?[0-9\\.\\,]+)",
'kin': "Igisubizo ni (\\-?[0-9\\.\\,]+)",
'twi': "Ne nnyiano yɛ (\\-?[0-9\\.\\,]+)",
'zul': "Impendulo ithi (\\-?[0-9\\.\\,]+)",
'orm': "Deebiin isaa (\\-?[0-9\\.\\,]+)",
'yor': "Ìdáhùn náà ni (\\-?[0-9\\.\\,]+)",
'hau': "Amsar ita ce (\\-?[0-9\\.\\,]+)",
'sot': "Karabo ke (\\-?[0-9\\.\\,]+)",
'swa': "Jibu ni (\\-?[0-9\\.\\,]+)",
}
-configs = {
-    "QUESTION": "Question:",
-    "ANSWER": "Step-by-Step Answer:",
-    "DIRECT": "Answer:",
-    "REGEX": "The answer is (\\-?[0-9\\.\\,]+)"}

LANGUAGES = {}
for lang in languages:
    if lang == 'amh':
        LANGUAGES[lang] = {  # Amharic
            "QUESTION": "ጥያቄ:",
            "ANSWER": "በቅደም ተከተል መልስ:",
            "DIRECT": "Answer:",
            "REGEX": languages_REGEX[lang]}
    elif lang == 'yor':
        LANGUAGES[lang] = {  # Yoruba
            "QUESTION": "Ìbéèrè:",
            "ANSWER": "Ìdáhùn lẹ́sẹsẹ:",
            "DIRECT": "Answer:",
            "REGEX": languages_REGEX[lang]}
    else:
        LANGUAGES[lang] = {  # English
            "QUESTION": "Question:",
            "ANSWER": "Step-by-Step Answer:",
            "DIRECT": "Answer:",
            "REGEX": languages_REGEX[lang]}
def add_regex_pattern(regex_pattern):
    if regex_pattern is None:
        return {}
    return {
        "filter_list": [
            {
                "name": "strict-match",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": f"""{regex_pattern}""",
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
            {
                "name": "flexible-extract",
                "filter": [
                    {
                        "function": "regex",
                        "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
                        "group_select": -1,
                    },
                    {
                        "function": "take_first",
                    },
                ],
            },
        ],
    }
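For reference, a small sketch of how the helper above can be inspected; it relies on the definitions in this file and is illustration only.

```python
import json

# Given the definitions above, build the filter configuration for the Hausa pattern.
config = add_regex_pattern(languages_REGEX["hau"])

# Two extraction passes are emitted: "strict-match" applies the language-specific
# answer pattern, while "flexible-extract" falls back to taking the last number.
print(json.dumps(config, ensure_ascii=False, indent=2))
```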
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
@@ -18,28 +91,70 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    :param overwrite: Whether to overwrite files if they already exist.
    """
    err = []
-    for lang in languages:
    for lang in LANGUAGES.keys():
        try:
            yaml_template = "cot_yaml"
            filter_list = {}
            DELIMITER = None
            if mode == "direct":
                ANSWER = LANGUAGES['eng']["DIRECT"]
                QUESTION = LANGUAGES['eng']["QUESTION"]
                REGEX = None
                task_name = f"afrimgsm_direct_{lang}"
-                yaml_template = "afrimgsm_common_yaml"
                yaml_template = "direct_yaml"
            if mode == "direct-native":
                ANSWER = LANGUAGES[lang]["DIRECT"]
                QUESTION = LANGUAGES[lang]["QUESTION"]
                REGEX = None
                task_name = f"afrimgsm_direct_native_{lang}"
                yaml_template = "direct_native_yaml"
            elif mode == "native-cot":
                ANSWER = LANGUAGES[lang]["ANSWER"]
                REGEX = LANGUAGES[lang]["REGEX"]
                QUESTION = LANGUAGES[lang]["QUESTION"]
                task_name = f"afrimgsm_native_cot_{lang}"
-                yaml_template = "afrimgsm_common_yaml"
                filter_list = add_regex_pattern(REGEX)
                DELIMITER = "" if lang in ["zh", "ja"] else None
            elif mode == "en-cot":
                ANSWER = LANGUAGES["eng"]["ANSWER"]
                REGEX = LANGUAGES["eng"]["REGEX"]
                QUESTION = LANGUAGES["eng"]["QUESTION"]
                task_name = f"afrimgsm_en_cot_{lang}"
-                yaml_template = "afrimgsm_common_yaml"
            elif mode == "translate-direct":
                ANSWER = LANGUAGES['eng']["DIRECT"]
                QUESTION = LANGUAGES['eng']["QUESTION"]
                REGEX = None
                task_name = f"translate_afrimgsm_direct_{lang}"
                yaml_template = "translate_direct_yaml"

            file_name = f"{task_name}.yaml"
            ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1
            with open(
                f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    {
                        "include": yaml_template,
                        "dataset_name": lang,
                        "task": f"{task_name}",
                        "doc_to_text": f"""{{% if answer is not none %}}"""
                        f"""{{{{question+"\\n{ANSWER}"}}}}"""
                        f"""{{% else %}}"""
                        f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
                        f"""{{% endif %}}""",
                        "doc_to_target": f"""{{% if answer is not none %}}"""
                        f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
                        f"""{{% else %}}"""
                        f"""{{{{answer_number|string}}}}"""
                        f"""{{% endif %}}""",
                        **filter_list,
                        "generation_kwargs": {
                            "until": [QUESTION, "</s>", "<|im_end|>"],
                            "do_sample": False,
                        },
                        **({"target_delimiter": DELIMITER} if DELIMITER else {}),
                    },
                    f,
                    allow_unicode=True,
@@ -60,17 +175,17 @@ def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
-        default=True,
        default=False,
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
-        "--output-dir", default="./direct", help="Directory to write yaml files to"
        "--output-dir", default=".", help="Directory to write yaml files to"
    )
    parser.add_argument(
        "--mode",
        default="native-cot",
-        choices=["direct","direct-native", "native-cot", "en-cot","translate-direct"],
        choices=["direct", "direct-native", "native-cot", "en-cot", "translate-direct"],
        help="Mode of chain-of-thought",
    )
    args = parser.parse_args()
@@ -79,4 +194,4 @@ def main() -> None:

if __name__ == "__main__":
    main()
\ No newline at end of file
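A minimal sketch of how the generator above is intended to be driven, assuming the module is importable as `utils` (the generated files say "# Generated by utils.py"); the call is roughly what `python utils.py --mode native-cot --overwrite` does.

```python
from utils import gen_lang_yamls

# Write one afrimgsm_native_cot_<lang>.yaml per entry in LANGUAGES into the current
# directory; each includes the "cot_yaml" template plus the language-specific
# question/answer prefixes and the answer-extraction regex filters.
gen_lang_yamls(output_dir=".", overwrite=True, mode="native-cot")
```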
@@ -4,7 +4,7 @@ group:
  - afrixnli-manual
dataset_path: masakhane/afrixnli
dataset_name: null
-output_type: generate_until
output_type: multiple_choice
validation_split: validation
test_split: test
fewshot_split: validation
@@ -16,14 +16,6 @@ doc_to_choice:
  - "contradiction"
should_decontaminate: true
doc_to_decontamination_query: premise
-filter_list:
-  - name: "verbalizer_extract"
-    filter:
-      - function: verbalizer
-        verbalizer_dict: {
-          "entailment": ['encouragement', 'entitlement', 'entails', 'entailed', 'entailment'],
-          "contradiction": ['contradictory', 'contradicts', 'contradiction'],
-          "neutral": ['neutral']}
metric_list:
  - metric: f1
    aggregation: !function utils.weighted_f1_score
@@ -32,7 +24,7 @@ metric_list:
    ignore_case: true
    ignore_punctuation: true
  - metric: acc
-    aggregation: !function utils.manual_accuracy_score
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
...
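Context for the `output_type` revert, as a rough sketch rather than harness internals: with `multiple_choice`, the three labels in `doc_to_choice` are scored directly against the prompt, so the free-text `verbalizer` filter that mapped generated strings back onto labels is no longer needed.

```python
# Rough sketch (not lm-evaluation-harness code): under multiple_choice, each candidate
# label is scored and the best-scoring one becomes the prediction.
from typing import Callable

CHOICES = ["entailment", "neutral", "contradiction"]

def predict(prompt: str, score: Callable[[str, str], float]) -> str:
    # `score` stands in for a model call returning e.g. the log-likelihood of the
    # continuation given the prompt.
    return max(CHOICES, key=lambda choice: score(prompt, " " + choice))
```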
-from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import f1_score

def doc_to_text(doc):
@@ -30,12 +30,3 @@ def weighted_f1_score(items):
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds, average="weighted")
    return fscore

-def manual_accuracy_score(items):
-    unzipped_list = list(zip(*items))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    accuracy = accuracy_score(golds, preds)
-    return accuracy
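After this change, `acc` aggregates with the built-in `mean`, while `f1` keeps the custom `weighted_f1_score` above. As a small, self-contained illustration of what that aggregation computes (the item pairs are made up):

```python
from sklearn.metrics import f1_score

# Each item is a (gold_label, predicted_label) pair collected per document; values invented.
items = [("entailment", "entailment"), ("neutral", "contradiction"),
         ("contradiction", "contradiction"), ("entailment", "neutral")]
golds, preds = zip(*items)

print(f1_score(golds, preds, average="weighted"))   # label-frequency-weighted F1
print(sum(g == p for g, p in items) / len(items))   # plain accuracy, i.e. what `mean` yields
```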
@@ -5,7 +5,7 @@ group:
  - afrixnli-translate-test
dataset_path: masakhane/afrixnli-translate-test
dataset_name: null
-output_type: generate_until
output_type: multiple_choice
test_split: test
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target
@@ -15,23 +15,15 @@ doc_to_choice:
  - "contradiction"
should_decontaminate: true
doc_to_decontamination_query: premise
-filter_list:
-  - name: "verbalizer_extract"
-    filter:
-      - function: verbalizer
-        verbalizer_dict: {
-          "entailment": ['encouragement', 'entitlement', 'entails', 'entailed', 'entailment'],
-          "contradiction": ['contradictory', 'contradicts', 'contradiction'],
-          "neutral": ['neutral']}
metric_list:
  - metric: f1
-    aggregation:
    aggregation: !function utils.weighted_f1_score
      average: weighted
    higher_is_better: True
    ignore_case: true
    ignore_punctuation: true
  - metric: acc
-    aggregation: !function utils.manual_accuracy_score
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
...
-from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import f1_score

def doc_to_text(doc):
@@ -30,12 +30,3 @@ def weighted_f1_score(items):
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds, average="weighted")
    return fscore

-def manual_accuracy_score(items):
-    unzipped_list = list(zip(*items))
-    golds = unzipped_list[0]
-    preds = unzipped_list[1]
-    accuracy = accuracy_score(golds, preds)
-    return accuracy