# noqa
import itertools
import re
from functools import cache
from typing import Literal, Generator, Union, TYPE_CHECKING

import datasets
from transformers import AutoTokenizer

from lm_eval.tasks.ruler.essays import get_all_essays
from lm_eval.tasks.ruler.prepare import generate_samples

if TYPE_CHECKING:
    import transformers


def get_tokenizer(
    **kwargs,
) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
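    """Load the tokenizer named in the task's metadata.

    Looks for a "tokenizer" (or "pretrained") entry under the "metadata"
    keyword argument and loads it with AutoTokenizer.
    """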
    metadata = kwargs.get("metadata", {})
    pretrained = metadata.get("tokenizer", metadata.get("pretrained", ""))
    assert pretrained, "No tokenizer or pretrained provided."
    print("using tokenizer ", pretrained)
    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


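# Prompt template shared by all NIAH variants below; {type_needle_v}, {context}
# and {query} are filled in by generate_samples. The commented-out variant
# additionally appends an answer prefix to the question.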
# TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are"""
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""


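# Context lengths (in tokens) to generate samples for; longer lengths are left
# commented out. Each length becomes its own metric key in process_results.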
SEQ_LENGTHS = (
    # 131072,
    # 65536,
    # 32768,
    16384,
    8192,
    4096,
)

NUM_SAMPLES = 500
REMOVE_NEWLINE_TAB = ""
STOP_WORDS = ""
RANDOM_SEED = 42


@cache
def get_haystack(
    type_haystack: Literal["essay", "repeat", "needle"],
) -> Union[list[str], str]:
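    """Return the filler text for the requested haystack variant.

    "essay" gives a list of words drawn from the bundled essays, "repeat" a
    fixed distractor sentence, and "needle" the needle template itself.
    Cached so the essay text is only loaded and split once.
    """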
    NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
    if type_haystack == "essay":
        essay = get_all_essays()["text"]
        # essay = json.load(open(essay))["text"]
        haystack = re.sub(r"\s+", " ", essay).split(" ")
    elif type_haystack == "repeat":
        haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    elif type_haystack == "needle":
        haystack = NEEDLE
    else:
        raise NotImplementedError(f"{type_haystack} is not implemented.")
    return haystack


def flatten(df: Generator) -> dict[str, datasets.Dataset]:
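    """Merge the per-sequence-length sample generators into one "test" split."""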
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }


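# Builders for the RULER needle-in-a-haystack (NIAH) task variants. Each
# callable receives the task's metadata kwargs (used to locate the tokenizer)
# and generates samples for every length in SEQ_LENGTHS, flattened into a
# single test split.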
# ruff: noqa
niah_single_1 = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="repeat"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="repeat",
        type_needle_k="words",
        type_needle_v="numbers",
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# ruff: noqa
niah_single_2 = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# noqa
niah_single_3 = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="uuids",
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# noqa
niah_multikey_1 = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        num_needle_k=4,
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# noqa
niah_multikey_2 = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="needle"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="needle",
        type_needle_k="words",
        type_needle_v="numbers",
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# noqa
niah_multikey_3 = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="needle"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="needle",
        type_needle_k="uuids",
        type_needle_v="uuids",
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# noqa
niah_multivalue = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        num_needle_v=4,
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)
# noqa
niah_multiquery = lambda **kwargs: flatten(
    generate_samples(
        get_haystack(type_haystack="essay"),
        max_seq_length=seq,
        template=TEMPLATE,
        type_haystack="essay",
        type_needle_k="words",
        type_needle_v="numbers",
        num_needle_q=4,
        TOKENIZER=get_tokenizer(**kwargs),
    )
    for seq in SEQ_LENGTHS
)


def postprocess_pred(predict_str: str) -> str:
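    """Normalize a raw model prediction before scoring."""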
    predict_str = predict_str.strip()

    # Replace non-printable (control) characters with newlines
    np_pattern = re.compile(r"[\x00-\x1f]")
    predict_str = np_pattern.sub("\n", predict_str).strip()

    return predict_str


def string_match_all(preds: list[str], refs: list[list[str]]) -> float:
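    """For each prediction, compute the fraction of its reference strings that
    appear in it (case-insensitive), then average over all predictions."""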
    score = sum(
        sum(1.0 if r.lower() in pred.lower() else 0.0 for r in ref) / len(ref)
        for pred, ref in zip(preds, refs)
    ) / len(preds)
    return score


def process_results(doc: dict, results: list[str]) -> dict[str, float]:
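    """Score one document, reporting the result under its own sequence length
    and -1.0 under every other length (filtered out in aggregate_metrics)."""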
    # hacky: set all other lengths to -1
    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
    input_len = doc["max_length"]
    pred = postprocess_pred(results[0])
    score = string_match_all([pred], [doc["outputs"]])
    metrics[str(input_len)] = score
    return metrics


def aggregate_metrics(metrics: list[float]) -> float:
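    """Average the per-sample scores for one sequence length, ignoring the
    -1.0 placeholders emitted by process_results for the other lengths."""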
    res = [x for x in metrics if x != -1]
    if not res:
        # we don't have any samples with this length
        return 0.0
    return sum(res) / len(res)
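

# Minimal sanity-check sketch of the scoring path (illustrative only, not part
# of the harness pipeline). The example doc fields mirror the keys that
# process_results actually reads ("max_length" and "outputs").
if __name__ == "__main__":
    example_doc = {"max_length": 4096, "outputs": ["7712"]}
    example_pred = "The special magic number mentioned in the text is 7712."
    # All lengths other than 4096 stay at -1.0; the 4096 entry gets score 1.0.
    print(process_results(example_doc, [example_pred]))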