niah_utils.py
import itertools
from typing import Generator

import datasets

from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.prepare import generate_samples, get_haystack

# Prompt template shared by all NIAH variants below; generate_samples fills
# {context} with the haystack (needles included) and {type_needle_v}/{query}
# with the needle value type and the queried key.
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""


def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
    """Flatten per-sequence-length sample batches into a single test split.

    Despite the name, nothing is downloaded: `df` is a generator yielding one
    list of generated samples per entry in SEQ_LENGTHS, and the lists are
    flattened into a single in-memory `datasets.Dataset`.
    """
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
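
# Illustrative shape of what the factories below pass in (the field names here
# are assumptions; the real ones come from generate_samples):
#   batches = ([{"input": ..., "outputs": [...]}] for _ in SEQ_LENGTHS)
#   download_dataset(batches)["test"]  # one Dataset spanning all lengths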


# Single needle in a haystack of repeated noise sentences; word key, number value.
# `metadata` (tokenizer config) is injected by the harness via **kwargs.
niah_single_1 = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="repeat"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="repeat",
        type_needle_k="words",
        type_needle_v="numbers",
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# Single needle in an essay haystack; word key, number value.
niah_single_2 = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# Single needle in an essay haystack; word key, UUID value.
niah_single_3 = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="uuids",
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# Four needle keys (num_needle_k=4) in an essay haystack; the extra keys act
# as distractors for the single queried key.
niah_multikey_1 = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        num_needle_k=4,
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# Needle-style haystack: the distractor text itself consists of needle-like
# key-value lines; word key, number value.
niah_multikey_2 = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="needle"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="needle",
        type_needle_k="words",
        type_needle_v="numbers",
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# Needle-style haystack with UUID keys and UUID values.
niah_multikey_3 = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="needle"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="needle",
        type_needle_k="uuids",
        type_needle_v="uuids",
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# One key with four values (num_needle_v=4) in an essay haystack; all values
# must be retrieved.
niah_multivalue = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        num_needle_v=4,
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
# Four queried keys (num_needle_q=4) in an essay haystack; one value each.
niah_multiquery = lambda **kwargs: download_dataset(  # noqa: E731
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        num_needle_q=4,
        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
    )
    for seq in SEQ_LENGTHS
)
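

# Usage sketch (illustrative): the harness normally resolves these factories
# from the task configs, but they can also be called directly. The metadata
# keys accepted by get_tokenizer are an assumption here, not confirmed by
# this file.
if __name__ == "__main__":
    data = niah_single_1(metadata={"pretrained": "gpt2"})  # hypothetical key
    print(data["test"])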