Unverified Commit f720ce81 authored by Jess's avatar Jess Committed by GitHub
Browse files

Merge pull request #20 from JessicaOjo/afri_mgsm

Afri mgsm modefied
parents facf38ca c56593ee
group:
- mgsm_direct
- afrimgsm
# This file will be included in the generated language-specific task configs.
# It doesn't have a yaml file extension as it is not meant to be imported directly
# by the harness.
group: afrimgsm_direct
dataset_path: masakhane/afrimgsm
dataset_name: null # Overridden by language-specific config.
output_type: generate_until
training_split: train
test_split: test
fewshot_split: train
target_delimiter: ""
doc_to_target: '{% if answer is not none %}{{answer}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nAnswer:"}}{% else %}{{"Question: "+question+"\nAnswer:"}}{% endif %}'
generation_kwargs:
do_sample: false
until:
- 'Question:'
- </s>
- <|im_end|>
- "\n\n"
- "\n"
do_sample: false
temperature: 0.0
filter_list:
- name: remove_whitespace
filter:
- function: remove_whitespace
- function: take_first
- filter:
- function: regex
group_select: -1
regex_pattern: (-?[0-9.,]{2,})|(-?[0-9]+)
regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
- function: take_first
name: flexible-extract
- filter:
- function: regex-numbers
group_select: -1
regex_pattern: (\d{1,10}(?:,\d{3})*(?:[.,]\d{3})?)([^\d()]*)
- function: take_first
name: flexible-extract-new
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
- metric: f1
aggregation: !function utils.weighted_f1_score
average: weighted
higher_is_better: True
ignore_case: true
ignore_punctuation: true
- metric: squad
aggregation: squad_f1
average: weighted
higher_is_better: True
ignore_case: true
ignore_punctuation: true
metadata:
version: 2.0
from sklearn.metrics import f1_score
def weighted_f1_score(items):
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="weighted")
return fscore
\ No newline at end of file
#!/bin/bash
python utils.py --overwrite --output-dir direct --mode direct
# python utils.py --overwrite --output-dir en_cot --mode en-cot
# python utils.py --overwrite --output-dir native_cot --mode native-cot
File mode changed from 100755 to 100644
import argparse
import yaml
import yaml
languages = ['eng', 'amh', 'ibo', 'fra', 'sna', 'lin', 'wol', 'ewe', 'lug', 'xho', 'kin', 'twi', 'zul', 'orm', 'yor', 'hau', 'sot', 'swa']
configs = {
"QUESTION": "Question:",
"ANSWER": "Step-by-Step Answer:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)"}
LANGUAGES = {}
for lang in languages:
LANGUAGES[lang] = { # English
"QUESTION": "Question:",
"ANSWER": "Step-by-Step Answer:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)"}
def add_regex_pattern(regex_pattern):
if regex_pattern is None:
return {}
return {
"filter_list": [
{
"name": "strict-match",
"filter": [
{
"function": "regex",
"regex_pattern": f"""{regex_pattern}""",
},
{
"function": "take_first",
},
],
},
{
"name": "flexible-extract",
"filter": [
{
"function": "regex",
"regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""",
"group_select": -1,
},
{
"function": "take_first",
},
],
},
],
}
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
......@@ -18,19 +55,31 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in languages:
for lang in LANGUAGES.keys():
try:
QUESTION = LANGUAGES[lang]["QUESTION"]
yaml_template = "cot_yaml"
filter_list = {}
DELIMITER = None
if mode == "direct":
ANSWER = LANGUAGES[lang]["DIRECT"]
REGEX = None
task_name = f"afrimgsm_direct_{lang}"
yaml_template = "afrimgsm_common_yaml"
yaml_template = "direct_yaml"
elif mode == "native-cot":
ANSWER = LANGUAGES[lang]["ANSWER"]
REGEX = LANGUAGES[lang]["REGEX"]
task_name = f"afrimgsm_native_cot_{lang}"
yaml_template = "afrimgsm_common_yaml"
filter_list = add_regex_pattern(REGEX)
DELIMITER = "" if lang in ["zh", "ja"] else None
elif mode == "en-cot":
ANSWER = LANGUAGES["en"]["ANSWER"]
REGEX = LANGUAGES["en"]["REGEX"]
task_name = f"afrimgsm_en_cot_{lang}"
yaml_template = "afrimgsm_common_yaml"
file_name = f"{task_name}.yaml"
ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1
with open(
f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
) as f:
......@@ -39,7 +88,23 @@ def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
{
"include": yaml_template,
"dataset_name": lang,
"task": f"{task_name}"
"task": f"{task_name}",
"doc_to_text": f"""{{% if answer is not none %}}"""
f"""{{{{question+"\\n{ANSWER}"}}}}"""
f"""{{% else %}}"""
f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
f"""{{% endif %}}""",
"doc_to_target": f"""{{% if answer is not none %}}"""
f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}"""
f"""{{% else %}}"""
f"""{{{{answer_number|string}}}}"""
f"""{{% endif %}}""",
**filter_list,
"generation_kwargs": {
"until": [QUESTION, "</s>", "<|im_end|>"],
"do_sample": False,
},
**({"target_delimiter": DELIMITER} if DELIMITER else {}),
},
f,
allow_unicode=True,
......@@ -60,16 +125,16 @@ def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=True,
default=False,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default="./direct", help="Directory to write yaml files to"
"--output-dir", default=".", help="Directory to write yaml files to"
)
parser.add_argument(
"--mode",
default="direct",
default="native-cot",
choices=["direct", "native-cot", "en-cot"],
help="Mode of chain-of-thought",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment