Unverified Commit 01108aca authored by Uanu, committed by GitHub

Add new GPQA tasks (CoT and generative variants) (#1482)



* Add new tasks of GPQA

* Add README

* Remove unused functions

* Remove unused functions

* Linters

* Add flexible match

* update

* Remove duplicate function

* Linter

* update

* Update lm_eval/filters/extraction.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* register multi_choice_regex

* Update

* run precommit

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
Co-authored-by: haileyschoelkopf <hailey@eleuther.ai>
parent 8a875e9a
@@ -15,6 +15,7 @@ FILTER_REGISTRY = {
    "lowercase": transformation.LowercaseFilter,
    "uppercase": transformation.UppercaseFilter,
    "map": transformation.MapFilter,
    "multi_choice_regex": extraction.MultiChoiceRegexFilter,
    # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
    # that takes an input and returns a scalar and then should select the max reward,
    # or should implement different filters for different ways of handling a reward model's inference.
...
import re
import sys
import unicodedata
from lm_eval.api.filter import Filter
@@ -67,3 +69,115 @@ class WhitespaceFilter(Filter):
        filtered_resps = [filter_set(resp) for resp in resps]

        return filtered_resps
class MultiChoiceRegexFilter(RegexFilter):
    """
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. Assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.
    """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        """
        regex_pattern: The basic regex pattern to use. If it fails to match, we fall back to a customized matching procedure:
            - step 1: map each choice string to its letter label ([A-Z]), then try to find that choice text in the response.
            - step 2: parse the choice with the regex :[\s]*([A-?]), where ? is the letter of the last choice (it varies with the number of choices).
        group_select: Selects the (group_select)th match from the findall result.
        ignore_case: Ignores case during step 1 matching.
        ignore_punctuation: Removes punctuation during step 1 matching.
        regexes_to_ignore: Removes matches of these regexes during step 1 matching.
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    def apply(self, resps, docs):
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)

        def find_match(regex, resp, convert_dict={}):
            match = regex.findall(resp)
            if match:
                match = match[self.group_select]
                if isinstance(match, tuple):
                    match = [m for m in match if m][0]
                match = match.strip()
                if match and match in convert_dict:
                    match = convert_dict[match]
            return match

        punct_tbl = dict.fromkeys(
            i
            for i in range(sys.maxunicode)
            if unicodedata.category(chr(i)).startswith("P")
        )

        def filter_ignores(st):
            if self.regexes_to_ignore is not None:
                for s in self.regexes_to_ignore:
                    st = re.sub(s, "", st)

            if self.ignore_case:
                st = st.lower()

            if self.ignore_punctuation:
                # https://stackoverflow.com/a/266162
                st = st.translate(punct_tbl)
            return st

        filtered_resps = []

        for r, doc in zip(resps, docs):
            fallback_regexes = []
            choice_to_alpha = {}
            next_alpha = "A"

            without_paren_fallback_regexes = []
            without_paren_to_target = {}

            choices = doc["choices"]
            for c in choices:
                m = filter_ignores(c.strip())
                fallback_regexes.append(f"{re.escape(m)}")
                choice_to_alpha[m] = f"({next_alpha})"

                without_paren_fallback_regexes.append(next_alpha)
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)

            fallback_regex = re.compile("|".join(fallback_regexes))
            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
            without_paren_fallback_regex = re.compile(
                f":[\s]*({without_paren_fallback_regex})"
            )

            filtered = []
            for resp in r:
                match = find_match(self.regex, resp)
                if not match:
                    match = find_match(
                        fallback_regex, filter_ignores(resp), choice_to_alpha
                    )
                if not match:
                    match = find_match(
                        without_paren_fallback_regex, resp, without_paren_to_target
                    )
                if not match:
                    match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
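For reference, a minimal usage sketch of the fallback behaviour (the choices, response string, and constructor arguments below are illustrative, not taken from any task config):

from lm_eval.filters.extraction import MultiChoiceRegexFilter

# Hypothetical example: the strict "(X)" pattern finds nothing, so the filter
# falls back to locating the choice text itself and mapping it to its letter.
filt = MultiChoiceRegexFilter(
    regex_pattern=r"(\([A-Z]\))",
    group_select=-1,
    ignore_case=True,
    ignore_punctuation=True,
)
docs = [{"choices": ["4 m/s", "8 m/s", "16 m/s", "32 m/s"]}]
resps = [["The ball doubles its speed twice, so it ends at 16 m/s."]]
print(filt.apply(resps, docs))  # -> [['(C)']]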
@@ -35,6 +35,9 @@ This dataset is gated, so you will have to accept the terms of use at https://hu
* `gpqa_{main, diamond, extended}_zeroshot`
* `gpqa_{main, diamond, extended}_n_shot`
* `gpqa_{main, diamond, extended}_generative_n_shot`
* `gpqa_{main, diamond, extended}_cot_zeroshot`
* `gpqa_{main, diamond, extended}_cot_n_shot`
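A minimal sketch of running one of the new tasks through the harness' Python API (the model choice and the result-dict keys shown are assumptions, not prescribed by this PR):

from lm_eval import simple_evaluate

# Illustrative only: any Hugging Face causal LM can be substituted.
results = simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-1.4b",
    tasks=["gpqa_main_cot_zeroshot"],
)
print(results["results"]["gpqa_main_cot_zeroshot"])  # per-task metrics (assumed key layout)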
### Checklist
...
import yaml
from tqdm import tqdm


def main() -> None:
    # Writes one task YAML per GPQA subset, each including the shared base config.
    subset = ["extended", "diamond", "main"]
    setting = "cot_n_shot"
    for task in tqdm(subset):
        file_name = f"gpqa_{task}_{setting}.yaml"
        try:
            with open(f"{file_name}", "w") as f:
                f.write("# Generated by _generate_configs.py\n")
                yaml.dump(
                    {
                        "include": f"_gpqa_{setting}_yaml",
                        "task": f"gpqa_{task}_{setting}",
                        "dataset_name": f"gpqa_{task}",
                    },
                    f,
                )
        except FileExistsError:
            pass


if __name__ == "__main__":
    main()
dataset_path: Idavidrein/gpqa
group: gpqa
output_type: generate_until
process_docs: !function utils.process_docs
training_split: train
# Because huggingface dataset only has train split
validation_split: train
test_split: null
description: "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n"
doc_to_text: "Question: {{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nLet's think step by step: "
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: "(?<=The answer is )(.*)(?=.)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      - function: "multi_choice_regex"
        group_select: -1
        ignore_case: true
        ignore_punctuation: true
        regex_pattern: "(\\([A-Z]\\))"
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
  do_sample: false
  temperature: 0.0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
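As a quick illustration of the strict-match filter above (the completion string is made up), the lookbehind/lookahead pattern captures everything after "The answer is " up to, but not including, the final character, which is typically the trailing period:

import re

# Toy completion, not taken from the dataset.
pattern = re.compile(r"(?<=The answer is )(.*)(?=.)")
completion = "Let's think step by step... The answer is (B)."
print(pattern.findall(completion))  # -> ['(B)']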
# Generated by _generate_configs.py
dataset_name: gpqa_diamond
include: _gpqa_cot_n_shot_yaml
task: gpqa_diamond_cot_n_shot
# Generated by _generate_configs.py
dataset_name: gpqa_extended
include: _gpqa_cot_n_shot_yaml
task: gpqa_extended_cot_n_shot
# Generated by _generate_configs.py
dataset_name: gpqa_main
include: _gpqa_cot_n_shot_yaml
task: gpqa_main_cot_n_shot
import random
import re

import datasets


def preprocess(text):
    if text is None:
        return " "
    text = text.strip()
    text = text.replace(" [title]", ". ")
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    return text


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        choices = [
            preprocess(doc["Incorrect Answer 1"]),
            preprocess(doc["Incorrect Answer 2"]),
            preprocess(doc["Incorrect Answer 3"]),
            preprocess(doc["Correct Answer"]),
        ]

        random.shuffle(choices)
        correct_answer_index = choices.index(preprocess(doc["Correct Answer"]))

        out_doc = {
            "choice1": choices[0],
            "choice2": choices[1],
            "choice3": choices[2],
            "choice4": choices[3],
            "choices": [choices[0], choices[1], choices[2], choices[3]],
            "answer": f"({chr(65 + correct_answer_index)})",
        }
        return out_doc

    return dataset.map(_process_doc)
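A sketch of what process_docs adds to each row (the values are invented and the exact letter depends on the shuffle order):

# Hypothetical input row:
raw = {
    "Question": "Which of these numbers is prime?",
    "Correct Answer": "7",
    "Incorrect Answer 1": "4",
    "Incorrect Answer 2": "6",
    "Incorrect Answer 3": "9",
}
# One possible output of _process_doc after shuffling:
# {
#     "choice1": "6", "choice2": "7", "choice3": "4", "choice4": "9",
#     "choices": ["6", "7", "4", "9"],
#     "answer": "(B)",  # letter of the correct answer's shuffled position
# }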
import yaml
from tqdm import tqdm


def main() -> None:
    subset = ["extended", "diamond", "main"]
    setting = "cot_zeroshot"
    for task in tqdm(subset):
        file_name = f"gpqa_{task}_{setting}.yaml"
        try:
            with open(f"{file_name}", "w") as f:
                f.write("# Generated by _generate_configs.py\n")
                yaml.dump(
                    {
                        "include": f"_gpqa_{setting}_yaml",
                        "task": f"gpqa_{task}_{setting}",
                        "dataset_name": f"gpqa_{task}",
                    },
                    f,
                )
        except FileExistsError:
            pass


if __name__ == "__main__":
    main()
dataset_path: Idavidrein/gpqa
group: gpqa
output_type: generate_until
process_docs: !function utils.process_docs
training_split: train
# Because huggingface dataset only has train split
validation_split: train
test_split: null
doc_to_text: "What is the correct answer to this question:{{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nLet's think step by step: "
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: "(?<=The answer is )(.*)(?=.)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      - function: "multi_choice_regex"
        group_select: -1
        ignore_case: true
        ignore_punctuation: true
        regex_pattern: "(\\([A-Z]\\))"
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
  do_sample: false
  temperature: 0.0
num_fewshot: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
# Generated by _generate_configs.py
dataset_name: gpqa_diamond
include: _gpqa_cot_zeroshot_yaml
task: gpqa_diamond_cot_zeroshot
# Generated by _generate_configs.py
dataset_name: gpqa_extended
include: _gpqa_cot_zeroshot_yaml
task: gpqa_extended_cot_zeroshot
# Generated by _generate_configs.py
dataset_name: gpqa_main
include: _gpqa_cot_zeroshot_yaml
task: gpqa_main_cot_zeroshot
import random
import re

import datasets


def preprocess(text):
    if text is None:
        return " "
    text = text.strip()
    text = text.replace(" [title]", ". ")
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    return text


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        choices = [
            preprocess(doc["Incorrect Answer 1"]),
            preprocess(doc["Incorrect Answer 2"]),
            preprocess(doc["Incorrect Answer 3"]),
            preprocess(doc["Correct Answer"]),
        ]

        random.shuffle(choices)
        correct_answer_index = choices.index(preprocess(doc["Correct Answer"]))

        out_doc = {
            "choice1": choices[0],
            "choice2": choices[1],
            "choice3": choices[2],
            "choice4": choices[3],
            "choices": [choices[0], choices[1], choices[2], choices[3]],
            "answer": f"({chr(65 + correct_answer_index)})",
        }
        return out_doc

    return dataset.map(_process_doc)
import yaml
from tqdm import tqdm


def main() -> None:
    subset = ["extended", "diamond", "main"]
    setting = "generative_n_shot"
    for task in tqdm(subset):
        file_name = f"gpqa_{task}_{setting}.yaml"
        try:
            with open(f"{file_name}", "w") as f:
                f.write("# Generated by _generate_configs.py\n")
                yaml.dump(
                    {
                        "include": f"_gpqa_{setting}_yaml",
                        "task": f"gpqa_{task}_{setting}",
                        "dataset_name": f"gpqa_{task}",
                    },
                    f,
                )
        except FileExistsError:
            pass


if __name__ == "__main__":
    main()
dataset_path: Idavidrein/gpqa
group: gpqa
output_type: generate_until
process_docs: !function utils.process_docs
training_split: train
# Because huggingface dataset only has train split
validation_split: train
test_split: null
description: "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n"
doc_to_text: "Question: {{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nAnswer:"
doc_to_target: answer
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: "(?<=The answer is )(.*)(?=.)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      - function: "multi_choice_regex"
        group_select: -1
        ignore_case: true
        ignore_punctuation: true
        regex_pattern: "(\\([A-Z]\\))"
      - function: "take_first"
generation_kwargs:
  until:
    - "</s>"
    - "Question:"
    - "<|im_end|>"
  temperature: 0.0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true
metadata:
  version: 1.0
# Generated by _generate_configs.py
dataset_name: gpqa_diamond
include: _gpqa_generative_n_shot_yaml
task: gpqa_diamond_generative_n_shot
# Generated by _generate_configs.py
dataset_name: gpqa_extended
include: _gpqa_generative_n_shot_yaml
task: gpqa_extended_generative_n_shot
# Generated by _generate_configs.py
dataset_name: gpqa_main
include: _gpqa_generative_n_shot_yaml
task: gpqa_main_generative_n_shot