utils.py 3.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import re
from functools import partial

from lm_eval.api.filter import Filter

# Option labels, assigned to a question's options in order. A question with
# more options than labels here makes format_cot_example raise IndexError.
choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"]


def format_cot_example(example, including_answer=True):
    """Render a multiple-choice question as a chain-of-thought prompt.

    The output has the form::

        Question:
        <question>
        Options:
        A. <option 0>
        B. <option 1>
        ...

    followed either by the example's reasoning (``including_answer=True``)
    or by the bare cue ``"Answer: Let's think step by step."``.

    Args:
        example: mapping with ``"question"`` and ``"options"`` keys, plus
            ``"cot_content"`` when ``including_answer`` is true.
        including_answer: append the example's reasoning (terminated by a
            blank line) instead of the open-ended answer cue.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: if a required key is missing from ``example``.
        IndexError: if there are more options than labels in ``choices``.
    """
    # Collect the pieces and join once instead of repeatedly concatenating.
    parts = ["Question:\n", example["question"], "\n", "Options:\n"]
    parts.extend(
        f"{choices[i]}. {opt}\n" for i, opt in enumerate(example["options"])
    )
    if including_answer:
        # Rewrite an "A: ..." reasoning prefix so it matches the "Answer: ..."
        # cue emitted below for prompts that are left unanswered.
        cot_content = example["cot_content"].replace(
            "A: Let's think step by step.", "Answer: Let's think step by step."
        )
        parts.append(cot_content + "\n\n")
    else:
        parts.append("Answer: Let's think step by step.")
    return "".join(parts)


# Prompt for the question being evaluated (open-ended answer cue).
doc_to_text = partial(format_cot_example, including_answer=False)
# Prompt for a worked example, including its reasoning.
fewshot_to_text = partial(format_cot_example, including_answer=True)


def process_docs(dataset, subject):
    """Keep only the rows of ``dataset`` whose ``"category"`` equals ``subject``.

    ``dataset`` must expose ``filter(predicate)`` (presumably a Hugging Face
    ``datasets.Dataset`` — confirm against the task config); the predicate
    receives a single row and reads its ``"category"`` key.
    """

    def _in_subject(row):
        return row["category"] == subject

    return dataset.filter(_in_subject)


# Per-subject document processors, each pre-bound to one category value.
process_biology = partial(process_docs, subject="biology")
process_business = partial(process_docs, subject="business")
process_chemistry = partial(process_docs, subject="chemistry")
process_computer_science = partial(process_docs, subject="computer_science")
process_economics = partial(process_docs, subject="economics")
process_engineering = partial(process_docs, subject="engineering")
process_health = partial(process_docs, subject="health")
process_history = partial(process_docs, subject="history")
process_law = partial(process_docs, subject="law")
process_math = partial(process_docs, subject="math")
process_other = partial(process_docs, subject="other")
process_philosophy = partial(process_docs, subject="philosophy")
process_physics = partial(process_docs, subject="physics")
process_psychology = partial(process_docs, subject="psychology")


# def generate_cot_prompt(val_df, curr, k):
#     prompt = ""
#     with open(f"cot_prompt_lib/initial_prompt.txt", "r") as fi:
#         for line in fi.readlines():
#             prompt += line
#     subject = curr["category"]
#     val_df = select_by_category(val_df, subject)
#     val_df = val_df[: k]
#     prompt = prompt.replace("{$}", subject) + "\n"
#     for example in val_df:
#         prompt += format_cot_example(example, including_answer=True)
#     prompt += format_cot_example(curr, including_answer=False)
#     return prompt

class CustomRegexFilter(Filter):
    """Extract an answer letter from model responses with an ordered list of regexes.

    For every response, the patterns are tried in order and capture group 1
    of the first pattern that matches is kept. A response that no pattern
    matches is replaced by ``fallback``, so the filtered output has the same
    shape as the input and stays aligned with ``docs``.
    """

    # Default patterns, tried in order: an explicit "answer is (X)" statement
    # first, then an "Answer: X" cue. Stored as a tuple so the shared default
    # cannot be mutated through one instance and leak into others.
    DEFAULT_REGEX_PATTERNS = (
        r"answer is \(?([ABCDEFGHIJ])\)?",
        r".*[aA]nswer:\s*([A-J])",
    )

    def __init__(
        self,
        regex_pattern=None,
        group_select=0,
        fallback: str = "[invalid]",
    ) -> None:
        """
        Args:
            regex_pattern: a single regex string, or a sequence of regex
                strings tried in order. Every pattern must define at least one
                capture group, because group 1 of the first match is returned.
                Defaults to ``DEFAULT_REGEX_PATTERNS``.
            group_select: accepted for interface compatibility; not used by
                ``apply`` (group 1 is always returned).
            fallback: value emitted for a response that no pattern matches.
        """
        if regex_pattern is None:
            regex_pattern = self.DEFAULT_REGEX_PATTERNS
        elif isinstance(regex_pattern, str):
            # A lone string would otherwise be iterated character by character.
            regex_pattern = [regex_pattern]
        self.regex_pattern = list(regex_pattern)
        self.regex = [re.compile(pattern) for pattern in self.regex_pattern]
        # NOTE(review): stored but never read by apply(); kept so existing
        # configs that pass it keep working.
        self.group_select = group_select
        self.fallback = fallback

    def _extract(self, resp):
        """Return group 1 of the first pattern matching ``resp``, else ``self.fallback``."""
        for pattern in self.regex:
            match = pattern.search(resp)
            if match:
                return match.group(1)
        return self.fallback

    def apply(self, resps, docs):
        """Filter each document's responses independently, preserving structure.

        Args:
            resps: one entry per input/target pair. Each entry is a list of
                model response strings; a bare string is also accepted and
                treated as a single response.
            docs: the matching documents (not used by this filter).

        Returns:
            A list with one entry per element of ``resps``. A list entry maps
            to a list of extracted answers and a string entry maps to a single
            extracted answer. ``fallback`` stands in for every response that
            no pattern matches, so nothing is dropped.
        """
        filtered_resps = []
        for inst in resps:
            if isinstance(inst, str):
                filtered_resps.append(self._extract(inst))
            else:
                filtered_resps.append([self._extract(resp) for resp in inst])
        return filtered_resps