# __init__.py
from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates

# TODO: decide whether prompts should be jinja2 templates or f-strings
# (supporting both may be more trouble than it is worth).
#
# Prompt library, stored as a two-level dictionary indexed first by prompt
# category name and then by prompt name, so a prompt is looked up as
# PROMPT_REGISTRY[category_name][prompt_name]. The prompts below use
# double-brace placeholders such as {{question}}.
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}

def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
    """Look up a prompt by its ``"<category>:<name>"`` identifier.

    Two sources are supported:

    * ``"promptsource:<template name>"`` -- templates are loaded with
      promptsource's ``DatasetTemplates`` for ``dataset_name`` (and
      ``subset_name`` when given), and the named template is returned.
    * ``"<category>:<prompt name>"`` for any other category -- the entry
      ``PROMPT_REGISTRY[category][prompt name]`` is returned.

    Args:
        prompt_id: Identifier containing exactly one ``:`` separating the
            prompt category from the prompt name.
        dataset_name: Dataset passed to ``DatasetTemplates``; only used for
            the ``promptsource`` category.
        subset_name: Optional dataset subset passed to ``DatasetTemplates``;
            only used for the ``promptsource`` category.

    Returns:
        The promptsource template (``prompts[prompt_name]``) or the
        registered prompt from ``PROMPT_REGISTRY``.

    Raises:
        ValueError: if ``prompt_id`` does not contain exactly one ``:``, if
            the promptsource templates cannot be loaded, or if the requested
            prompt does not exist.
    """
    # Validate the separator up front so a malformed id gets a clear message
    # instead of a bare tuple-unpacking error.
    parts = prompt_id.split(":")
    if len(parts) != 2:
        raise ValueError(
            "expected only a single `:` as separator between "
            f"prompt category and name, but got `{prompt_id}` instead"
        )
    category_name, prompt_name = parts
    eval_logger.info(f"Loading prompt from {category_name}")

    if category_name == "promptsource":
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception as err:
            # The exception type DatasetTemplates raises for an unknown
            # dataset/subset is not pinned down here, so convert any load
            # failure (but not KeyboardInterrupt/SystemExit) and keep the cause.
            raise ValueError(f"{dataset_name} and {subset_name} not found") from err
        if prompt_name not in prompts.all_template_names:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
        return prompts[prompt_name]

    try:
        return PROMPT_REGISTRY[category_name][prompt_name]
    except KeyError:
        # Unknown category or prompt name: report what was actually missing.
        raise ValueError(
            f"prompt `{prompt_name}` not found in category `{category_name}` "
            f"of PROMPT_REGISTRY (prompt id `{prompt_id}`)"
        ) from None