Unverified Commit f720ce81 authored by Jess's avatar Jess Committed by GitHub
Browse files

Merge pull request #20 from JessicaOjo/afri_mgsm

Afri mgsm modefied
parents facf38ca c56593ee
group: # This file will be included in the generated language-specific task configs.
- mgsm_direct # It doesn't have a yaml file extension as it is not meant to be imported directly
- afrimgsm # by the harness.
group: afrimgsm_direct
dataset_path: masakhane/afrimgsm dataset_path: masakhane/afrimgsm
dataset_name: null # Overridden by language-specific config.
output_type: generate_until output_type: generate_until
training_split: train training_split: train
test_split: test test_split: test
fewshot_split: train
target_delimiter: "" target_delimiter: ""
doc_to_target: '{% if answer is not none %}{{answer}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nAnswer:"}}{% else %}{{"Question: "+question+"\nAnswer:"}}{% endif %}'
generation_kwargs: generation_kwargs:
do_sample: false
until: until:
- 'Question:' - "\n\n"
- </s> - "\n"
- <|im_end|> do_sample: false
temperature: 0.0
filter_list: filter_list:
- name: remove_whitespace
filter:
- function: remove_whitespace
- function: take_first
- filter: - filter:
- function: regex - function: regex
group_select: -1 group_select: -1
regex_pattern: (-?[0-9.,]{2,})|(-?[0-9]+) regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
- function: take_first - function: take_first
name: flexible-extract name: flexible-extract
- filter:
- function: regex-numbers
group_select: -1
regex_pattern: (\d{1,10}(?:,\d{3})*(?:[.,]\d{3})?)([^\d()]*)
- function: take_first
name: flexible-extract-new
metric_list: metric_list:
- metric: exact_match - metric: exact_match
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
- metric: f1
aggregation: !function utils.weighted_f1_score
average: weighted
higher_is_better: True
ignore_case: true
ignore_punctuation: true
- metric: squad
aggregation: squad_f1
average: weighted
higher_is_better: True
ignore_case: true
ignore_punctuation: true
metadata: metadata:
version: 2.0 version: 2.0
from sklearn.metrics import f1_score
def weighted_f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="weighted")
return fscore
\ No newline at end of file
#!/bin/bash
python utils.py --overwrite --output-dir direct --mode direct
# python utils.py --overwrite --output-dir en_cot --mode en-cot
# python utils.py --overwrite --output-dir native_cot --mode native-cot
File mode changed from 100755 to 100644
import argparse import argparse
import yaml
import yaml
languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa'] languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
configs = { LANGUAGES = {}
"QUESTION": "Question:",
"ANSWER": "Step-by-Step Answer:", for lang in languages:
"DIRECT": "Answer:", LANGUAGES[lang] = { # English
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)"} "QUESTION": "Question:",
"ANSWER": "Step-by-Step Answer:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)"}
def add_regex_pattern(regex_pattern):
if regex_pattern is None:
return {}
return {
"filter_list": [
{
"name": "strict-match",
"filter": [
{
"function": "regex",
"regex_pattern": f"""{regex_pattern}""",
},
{
"function": "take_first",
},
],
},
{
"name": "flexible-extract",
"filter": [
{
"function": "regex",
"regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
"group_select": -1,
},
{
"function": "take_first",
},
],
},
],
}
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
...@@ -18,19 +55,31 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: ...@@ -18,19 +55,31 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
:param overwrite: Whether to overwrite files if they already exist. :param overwrite: Whether to overwrite files if they already exist.
""" """
err = [] err = []
for lang in languages: for lang in LANGUAGES.keys():
try: try:
QUESTION = LANGUAGES[lang]["QUESTION"]
yaml_template = "cot_yaml"
filter_list = {}
DELIMITER = None
if mode == "direct": if mode == "direct":
ANSWER = LANGUAGES[lang]["DIRECT"]
REGEX = None
task_name = f"afrimgsm_direct_{lang}" task_name = f"afrimgsm_direct_{lang}"
yaml_template = "afrimgsm_common_yaml" yaml_template = "direct_yaml"
elif mode == "native-cot": elif mode == "native-cot":
ANSWER = LANGUAGES[lang]["ANSWER"]
REGEX = LANGUAGES[lang]["REGEX"]
task_name = f"afrimgsm_native_cot_{lang}" task_name = f"afrimgsm_native_cot_{lang}"
yaml_template = "afrimgsm_common_yaml" filter_list = add_regex_pattern(REGEX)
DELIMITER = "" if lang in ["zh", "ja"] else None
elif mode == "en-cot": elif mode == "en-cot":
ANSWER = LANGUAGES["en"]["ANSWER"]
REGEX = LANGUAGES["en"]["REGEX"]
task_name = f"afrimgsm_en_cot_{lang}" task_name = f"afrimgsm_en_cot_{lang}"
yaml_template = "afrimgsm_common_yaml"
file_name = f"{task_name}.yaml" file_name = f"{task_name}.yaml"
ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1
with open( with open(
f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8" f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
) as f: ) as f:
...@@ -39,7 +88,23 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: ...@@ -39,7 +88,23 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
{ {
"include": yaml_template, "include": yaml_template,
"dataset_name": lang, "dataset_name": lang,
"task": f"{task_name}" "task": f"{task_name}",
"doc_to_text": f"""{{% if answer is not none %}}"""
f"""{{{{question+"\\n{ANSWER}"}}}}"""
f"""{{% else %}}"""
f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
f"""{{% endif %}}""",
"doc_to_target": f"""{{% if answer is not none %}}"""
f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
f"""{{% else %}}"""
f"""{{{{answer_number|string}}}}"""
f"""{{% endif %}}""",
**filter_list,
"generation_kwargs": {
"until": [QUESTION, "</s>", "<|im_end|>"],
"do_sample": False,
},
**({"target_delimiter": DELIMITER} if DELIMITER else {}),
}, },
f, f,
allow_unicode=True, allow_unicode=True,
...@@ -60,16 +125,16 @@ def main() -> None: ...@@ -60,16 +125,16 @@ def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--overwrite", "--overwrite",
default=True, default=False,
action="store_true", action="store_true",
help="Overwrite files if they already exist", help="Overwrite files if they already exist",
) )
parser.add_argument( parser.add_argument(
"--output-dir", default="./direct", help="Directory to write yaml files to" "--output-dir", default=".", help="Directory to write yaml files to"
) )
parser.add_argument( parser.add_argument(
"--mode", "--mode",
default="direct", default="native-cot",
choices=["direct", "native-cot", "en-cot"], choices=["direct", "native-cot", "en-cot"],
help="Mode of chain-of-thought", help="Mode of chain-of-thought",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment