Commit ab96fc7e authored by lintangsutawika's avatar lintangsutawika
Browse files

merged with latest update

parents bf2517cc 8680e938
"dataset_name": "history"
"include": "_default_haerae_yaml"
"task": "haerae_history"
"dataset_name": "loan_words"
"include": "_default_haerae_yaml"
"task": "haerae_loan_word"
"dataset_name": "rare_words"
"include": "_default_haerae_yaml"
"task": "haerae_rare_word"
"dataset_name": "standard_nomenclature"
"include": "_default_haerae_yaml"
"task": "haerae_standard_nomenclature"
......@@ -5,14 +5,24 @@ output_type: generate_until
doc_to_text: "Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA: Let's think step by step."
doc_to_target: "{{['(A)', '(B)', '(C)', '(D)'][answer]}}"
filter_list:
- name: "get-answer"
- name: "strict-match"
filter:
- function: "regex"
regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))"
- function: "take_first"
- name: "flexible-extract"
filter:
- function: !function utils.MultiChoiceRegexFilter
group_select: -1
ignore_case: true
ignore_punctuation: true
regex_pattern: "(\\([A-Z]\\))"
- function: "take_first"
generation_kwargs:
until:
- "</s>"
- "Q:"
- "<|im_end|>"
do_sample: false
temperature: 0.0
num_fewshot: 0
......@@ -23,4 +33,4 @@ metric_list:
ignore_case: true
ignore_punctuation: true
metadata:
version: 0.0
version: 1.0
import re
import sys
import unicodedata
from lm_eval.filters.extraction import RegexFilter
class MultiChoiceRegexFilter(RegexFilter):
""" """
def __init__(
self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]",
ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith('P'))
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)
if self.ignore_case:
st = st.lower()
if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st
filtered_resps = []
for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = 'A'
without_paren_fallback_regexes = []
without_paren_to_target = {}
choices = doc['choices']
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"
without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile('|'.join(fallback_regexes))
without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})")
filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha)
if not match:
match = find_match(without_paren_fallback_regex, resp, without_paren_to_target)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)
return filtered_resps
......@@ -7,13 +7,27 @@ fewshot_config:
output_type: generate_until
doc_to_text: "Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA:"
doc_to_target: "{{['(A)', '(B)', '(C)', '(D)'][answer]}}"
filter_list:
- name: "strict-match"
filter:
- function: "take_first"
- name: "flexible-extract"
filter:
- function: !function utils.MultiChoiceRegexFilter
group_select: 0
regex_pattern: "(\\([A-Z]\\))"
ignore_case: true
ignore_punctuation: true
- function: "take_first"
generation_kwargs:
until:
- "</s>"
- "Q:"
- "<|im_end|>"
- "<0x0A>"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
metadata:
version: 0.0
version: 1.0
import re
import sys
import unicodedata
from lm_eval.filters.extraction import RegexFilter
class MultiChoiceRegexFilter(RegexFilter):
""" """
def __init__(
self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]",
ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith('P'))
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)
if self.ignore_case:
st = st.lower()
if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st
filtered_resps = []
for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = 'A'
without_paren_fallback_regexes = []
without_paren_to_target = {}
choices = doc['choices']
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"
without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile('|'.join(fallback_regexes))
without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})")
filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(fallback_regex, filter_ignores(resp), choice_to_alpha)
if not match:
match = find_match(without_paren_fallback_regex, resp, without_paren_to_target)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)
return filtered_resps
group:
- m_mmlu
dataset_path: alexandrainst/m_mmlu
test_split: test
fewshot_split: train
fewshot_config:
sampler: first_n
output_type: multiple_choice
doc_to_text: "{{instruction.strip()}}\nA. {{option_a}}\nB. {{option_b}}\nC. {{option_c}}\nD. {{option_d}}\nAnswer:"
doc_to_choice: ["A", "B", "C", "D"]
doc_to_target: answer
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
metadata:
version: 0.0
import yaml
import datasets
from tqdm import tqdm
def main() -> None:
dataset_path = "alexandrainst/m_mmlu"
# Removed hy and sk subdataset because the original dataset is broken
# I created this PR https://huggingface.co/datasets/alexandrainst/m_mmlu/discussions/3
# on the dataset for the authors, in case it will be accepeted the filter can be removed
keys_without_hy_sk = list(filter(lambda k: ('hy' not in k and 'sk' not in k),
datasets.get_dataset_infos(dataset_path).keys()))
for task in tqdm():
file_name = f"m_mmlu_{task}.yaml"
try:
with open(f"{file_name}", "w") as f:
f.write("# Generated by _generate_configs.py\n")
yaml.dump(
{
"include": "_default_yaml",
"task": f"{dataset_path.split('/')[-1]}_{task}",
"dataset_name": task,
},
f,
)
except FileExistsError:
pass
if __name__ == "__main__":
main()
# Generated by _generate_configs.py
dataset_name: ar
include: _default_yaml
task: m_mmlu_ar
# Generated by _generate_configs.py
dataset_name: bn
include: _default_yaml
task: m_mmlu_bn
# Generated by _generate_configs.py
dataset_name: ca
include: _default_yaml
task: m_mmlu_ca
# Generated by _generate_configs.py
dataset_name: da
include: _default_yaml
task: m_mmlu_da
# Generated by _generate_configs.py
dataset_name: de
include: _default_yaml
task: m_mmlu_de
# Generated by _generate_configs.py
dataset_name: en
include: _default_yaml
task: m_mmlu_en
# Generated by _generate_configs.py
dataset_name: es
include: _default_yaml
task: m_mmlu_es
# Generated by _generate_configs.py
dataset_name: eu
include: _default_yaml
task: m_mmlu_eu
# Generated by _generate_configs.py
dataset_name: fr
include: _default_yaml
task: m_mmlu_fr
# Generated by _generate_configs.py
dataset_name: gu
include: _default_yaml
task: m_mmlu_gu
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment