__init__.py 4.45 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
import ast
Lintang Sutawika's avatar
Lintang Sutawika committed
2
import logging
3
import os
4
from typing import Dict
5

6
import lm_eval.tasks
Baber's avatar
Baber committed
7
import lm_eval.utils
8
from lm_eval import utils
9

10

Lintang Sutawika's avatar
Lintang Sutawika committed
11
12
eval_logger = logging.getLogger(__name__)

# Prompt library.
#
# Prompts live in a two-level mapping: the outer key is the prompt category
# and the inner key is the prompt name. A prompt id of the form
# "<category>:<prompt>" (as split by get_prompt) therefore resolves to
# PROMPT_REGISTRY[category][prompt].
PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
    "qa-basic": dict(
        [
            ("question-newline-answer", "Question: {{question}}\nAnswer:"),
            ("q-newline-a", "Q: {{question}}\nA:"),
        ]
    ),
}

lintangsutawika's avatar
lintangsutawika committed
24

Ethan Smith's avatar
Ethan Smith committed
25
def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
    """Resolve a prompt id of the form ``"<category>:<prompt_name>"`` to a prompt.

    The category selects where the prompt comes from:

    * ``"promptsource"``: the template named ``prompt_name`` is taken from
      promptsource's ``DatasetTemplates`` for ``dataset_name`` (and
      ``subset_name`` when given) and returned as-is.
    * any category containing ``".yaml"``: the category is opened as a YAML
      file and ``prompts[prompt_name]`` from it is returned wrapped in a
      ``PromptString``.
    * anything else: the prompt string ``PROMPT_REGISTRY[category][prompt_name]``
      is returned.

    Args:
        prompt_id: Prompt identifier; must contain exactly one ``":"``.
        dataset_name: Dataset name, used for the promptsource lookup and logging.
        subset_name: Optional dataset subset name.

    Returns:
        A promptsource template, a ``PromptString``, or a registry prompt
        string, depending on the category.

    Raises:
        ValueError: if ``prompt_id`` does not contain exactly one ``":"``, the
            promptsource dataset/template cannot be found, or the registry has
            no such category/prompt.
        ModuleNotFoundError: if a promptsource prompt is requested but
            promptsource is not installed.
        KeyError: if a YAML prompt file has no ``prompts`` entry named
            ``prompt_name``.
    """
    # Unpack the prompt id, rejecting anything that is not exactly
    # "<category>:<name>" with a message that says so.
    parts = prompt_id.split(":")
    if len(parts) != 2:
        raise ValueError(
            "expected only a single `:` as separator between "
            f"prompt category and name, but got `{prompt_id}` instead"
        )
    category_name, prompt_name = parts

    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(
        "Loading prompt from %s for %s", category_name, dataset_full_name
    )

    if category_name == "promptsource":
        try:
            from promptsource.templates import DatasetTemplates
        except ModuleNotFoundError as exception:
            # Re-raise with an install hint. The message must be ONE string;
            # passing two arguments would render it as a tuple.
            raise type(exception)(
                "Tried to load a Promptsource template, but promptsource is not installed; "
                "please install promptsource via pip install lm-eval[promptsource] "
                "or pip install -e .[promptsource]"
            ) from exception

        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception as err:
            # Any failure constructing the templates is reported as "not found";
            # the original error is kept as the cause for debugging.
            raise ValueError(f"{dataset_name} and {subset_name} not found") from err

        if prompt_name not in prompts.all_template_names:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
        return prompts[prompt_name]

    if ".yaml" in category_name:
        import yaml

        with open(category_name, "rb") as file:
            # NOTE(review): yaml.full_load is not the safe loader. If prompt
            # files may come from untrusted sources, consider yaml.safe_load.
            prompt_yaml_file = yaml.full_load(file)

        prompt_string = prompt_yaml_file["prompts"][prompt_name]
        return PromptString(prompt_string)

    try:
        return PROMPT_REGISTRY[category_name][prompt_name]
    except KeyError:
        raise ValueError(
            f"prompt `{prompt_name}` not found in category `{category_name}` "
            f"of the prompt registry (prompt id `{prompt_id}`)"
        ) from None
72
73


lintangsutawika's avatar
lintangsutawika committed
74
75
76
def load_prompt_list(
    use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
    """Expand a (possibly wildcarded) prompt spec into a list of prompt ids.

    ``use_prompt`` has the form ``"<category>:<pattern>"``. The pattern is
    matched with ``utils.pattern_match`` against the prompt names available in
    the category (presumably glob-style wildcards; see that helper). Every
    match is returned as ``"<category>:<name>"``, the same form ``get_prompt``
    accepts.

    Args:
        use_prompt: Prompt spec; must contain exactly one ``":"``.
        dataset_name: promptsource dataset name (``promptsource`` category only).
        subset_name: Optional promptsource subset name.
        yaml_path: If given, a ``.yaml`` category is resolved relative to this
            directory. The returned ids then carry the resolved real path as
            their category.
        **kwargs: Accepted and ignored.

    Returns:
        A list of ``"<category>:<name>"`` prompt ids.

    Raises:
        ValueError: if ``use_prompt`` does not contain exactly one ``":"``, or
            its category is not ``promptsource``, a ``.yaml`` path, or a
            ``PROMPT_REGISTRY`` category.
        ModuleNotFoundError: if the ``promptsource`` category is used but
            promptsource is not installed.
    """
    parts = use_prompt.split(":")
    if len(parts) != 2:
        raise ValueError(
            "expected only a single `:` as separator between "
            f"prompt category and name, but got `{use_prompt}` instead"
        )
    category_name, prompt_name = parts

    if category_name == "promptsource":
        try:
            from promptsource.templates import DatasetTemplates
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "Tried to load a Promptsource template, but promptsource is not installed; "
                "please install promptsource via pip install lm-eval[promptsource] "
                "or pip install -e .[promptsource]"
            ) from exception

        if subset_name is None:
            prompts = DatasetTemplates(dataset_name=dataset_name)
        else:
            prompts = DatasetTemplates(
                dataset_name=dataset_name, subset_name=subset_name
            )

        prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)

    elif ".yaml" in category_name:
        import yaml

        if yaml_path is not None:
            # Rewrite the category to the resolved path so that the returned
            # ids can be opened later without knowing yaml_path.
            category_name = os.path.realpath(os.path.join(yaml_path, category_name))

        with open(category_name, "rb") as file:
            # NOTE(review): yaml.full_load is not the safe loader. If prompt
            # files may come from untrusted sources, consider yaml.safe_load.
            prompt_yaml_file = yaml.full_load(file)

        prompt_list = utils.pattern_match(
            prompt_name, prompt_yaml_file["prompts"].keys()
        )

    elif category_name in PROMPT_REGISTRY:
        # Built-in prompts, which get_prompt also resolves. Without this
        # branch prompt_list would be unbound (UnboundLocalError).
        prompt_list = utils.pattern_match(
            prompt_name, PROMPT_REGISTRY[category_name].keys()
        )

    else:
        raise ValueError(
            f"unknown prompt category `{category_name}` in `{use_prompt}`: expected "
            "`promptsource`, a `.yaml` file path, or one of "
            f"{sorted(PROMPT_REGISTRY)}"
        )

    # TODO: allow several prompt names/patterns in a single `use_prompt`.
    return [":".join([category_name, prompt]) for prompt in prompt_list]
113
114
115


class PromptString:
    """A prompt entry (e.g. from a YAML prompt file) rendered against a document.

    ``get_prompt`` builds one from ``prompts[<name>]`` of a ``.yaml`` prompt
    file. Despite the attribute name, ``prompt_string`` is a mapping:
    ``apply`` reads its ``"doc_to_text"`` and ``"doc_to_target"`` entries as
    templates.
    """

    def __init__(self, prompt_string):
        # Mapping holding the "doc_to_text" / "doc_to_target" templates, and
        # possibly an unsupported "doc_to_choice" entry.
        self.prompt_string = prompt_string

    def apply(self, doc):
        """Render this prompt for ``doc``.

        Args:
            doc: Passed to ``lm_eval.utils.apply_template`` as the template
                context (presumably a dict of dataset fields; TODO confirm).

        Returns:
            ``[text_string, target_string]``: the rendered ``doc_to_text`` and
            ``doc_to_target`` templates.

        Raises:
            NotImplementedError: if the prompt defines ``doc_to_choice``.
            KeyError: if ``doc_to_text`` or ``doc_to_target`` is missing.
        """
        # Check the unsupported field before reading the required keys, so
        # such prompts fail with the intended NotImplementedError rather than
        # a KeyError on a missing key.
        # TODO need a way to process doc_to_choice
        if "doc_to_choice" in self.prompt_string:
            raise NotImplementedError("Not yet implemented to accept doc_to_choice")

        doc_to_text = self.prompt_string["doc_to_text"]
        doc_to_target = self.prompt_string["doc_to_target"]

        text_string = lm_eval.utils.apply_template(doc_to_text, doc)
        target_string = lm_eval.utils.apply_template(doc_to_target, doc)

        return [text_string, target_string]