import ast
from typing import Optional

from lm_eval import utils
from lm_eval.logger import eval_logger

# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts by a `category:name` id (see `get_prompt`).
# Built-in prompt templates: category name -> prompt name -> template text.
# NOTE(review): `{{question}}` looks like a Jinja-style placeholder that is
# filled in by whoever renders the template -- confirm against the caller.
PROMPT_REGISTRY: dict[str, dict[str, str]] = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}


def get_prompt(
    prompt_id: str,
    dataset_name: Optional[str] = None,
    subset_name: Optional[str] = None,
):
    """Resolve a ``category:name`` prompt id to a prompt.

    Args:
        prompt_id: Identifier of the form ``"<category>:<prompt name>"``,
            containing exactly one ``:``.
        dataset_name: Dataset passed to promptsource's ``DatasetTemplates``;
            only used when the category is ``"promptsource"`` (and for the
            log line).
        subset_name: Optional dataset subset, forwarded to
            ``DatasetTemplates`` when given.

    Returns:
        For the ``"promptsource"`` category, ``prompts[prompt_name]`` from the
        matching ``DatasetTemplates``; otherwise the template string stored in
        ``PROMPT_REGISTRY[category][name]``.

    Raises:
        ValueError: If ``prompt_id`` does not contain exactly one ``:``, if
            the promptsource dataset/subset cannot be loaded, or if the
            prompt name is unknown.
        Exception: If the ``"promptsource"`` category is requested but the
            ``promptsource`` package is not installed.
    """
    # unpack prompt name
    try:
        category_name, prompt_name = prompt_id.split(":")
    except ValueError:
        # Too few / too many pieces: report the real problem instead of the
        # generic tuple-unpacking error.
        raise ValueError(
            "expected only a single `:` as separator between "
            f"prompt category and name, but got `{prompt_id}` instead"
        ) from None

    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")

    if category_name == "promptsource":
        # promptsource is an optional dependency, so import it lazily.
        try:
            from promptsource.templates import DatasetTemplates
        except ModuleNotFoundError as err:
            # A single message string: passing two positional arguments would
            # make the exception render as a tuple.
            raise Exception(
                "Tried to load a Promptsource template, but promptsource is not installed, "
                "please install promptsource via pip install lm-eval[promptsource] "
                "or pip install -e .[promptsource]"
            ) from err

        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception as err:
            # Broad on purpose: the failure modes of DatasetTemplates are not
            # visible here. Chain the cause so it is not lost.
            raise ValueError(f"{dataset_name} and {subset_name} not found") from err

        if prompt_name not in prompts.all_template_names:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
        return prompts[prompt_name]

    try:
        return PROMPT_REGISTRY[category_name][prompt_name]
    except KeyError as err:
        # The id was well-formed (split succeeded above); the category or the
        # prompt name is simply not registered.
        raise ValueError(
            f"prompt `{prompt_id}` not found in the prompt registry: "
            f"unknown category or prompt name {err}"
        ) from err


def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs):
    """Expand a promptsource prompt pattern into fully qualified prompt ids.

    Args:
        use_prompt: ``"<category>:<pattern>"``. Everything after the first
            ``:`` is split on ``:`` and handed to ``utils.pattern_match`` as
            the list of patterns.
        dataset_name: Dataset passed to promptsource's ``DatasetTemplates``.
        subset_name: Optional dataset subset, forwarded when given.
        **kwargs: Accepted and ignored.

    Returns:
        A list of ``"<category>:<template name>"`` strings, one per template
        name returned by ``utils.pattern_match``.

    Raises:
        ModuleNotFoundError: If the ``promptsource`` package is not installed.
    """
    # promptsource is an optional dependency, so import it lazily and explain
    # how to install it when missing (same hint as get_prompt gives).
    try:
        from promptsource.templates import DatasetTemplates
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "Tried to load a Promptsource template, but promptsource is not installed, "
            "please install promptsource via pip install lm-eval[promptsource] "
            "or pip install -e .[promptsource]"
        ) from err

    if subset_name is None:
        prompts = DatasetTemplates(dataset_name=dataset_name)
    else:
        prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)

    # Star-unpacking yields a *list* of patterns (empty when there is no `:`).
    category_name, *prompt_patterns = use_prompt.split(":")

    # TODO: allow naming multiple prompts explicitly.
    matched_names = utils.pattern_match(prompt_patterns, prompts.all_template_names)
    return [":".join([category_name, name]) for name in matched_names]