# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
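"""Data-construction utilities for the RULER Common Words Extraction (CWE) task.

Builds synthetic prompts containing a numbered word list in which a handful of
"common" words repeat far more often than the distractors; the model must name
the ten most frequent words. List sizes are calibrated so each prompt fits a
target sequence length.
"""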
import itertools
import random
from functools import cache

import datasets
import wonderwords
from tqdm import tqdm
from transformers import AutoTokenizer

from lm_eval.tasks.ruler.utils import SEQ_LENGTHS


RNG = random.Random(42)
TEMPLATE = "Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?\n\nAnswer: The top 10 words that appear most often in the list are:"


r = wonderwords.RandomWord()
# Candidate pool: the union of wonderwords' noun/adjective/verb lists (via its
# private `_categories` mapping), deduplicated and sorted so the seeded shuffle
# below is reproducible across runs.
WORDS = sorted(
    {item for x in ["noun", "adjective", "verb"] for item in r._categories[x]}
)
RNG.shuffle(WORDS)


def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10):
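    """Build one numbered word list.

    Picks `num_words` distinct words, repeats the first `common_nums` of them
    `common_repeats` times each and the rest `uncommon_repeats` times, then
    shuffles. Returns the formatted context and the list of common (target) words.
    """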
    # Sample from the seeded RNG (not the global `random` module) so the data
    # is reproducible for a fixed seed.
    word_list_full = RNG.sample(WORDS, num_words)
    common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:]
    word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats)
    RNG.shuffle(word_list)

    # Formatting the word list as "1. word1 2. word2 3. word3 ..."
    context = " ".join([f"{i + 1}. {word}" for i, word in enumerate(word_list)])

    return context, common


def generate_input_output(
    num_words: int,
    max_seq_length: int,
    freq_cw: int = 30,
    freq_ucw: int = 3,
    num_cw: int = 10,
):
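    """Build the one-shot example, the test prompt, and its answer list."""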
    # Short-context settings get a gentler frequency split so the task stays solvable.
    if max_seq_length < 4096:
        context_example, answer_example = get_example(20, 3, 1, num_cw)
        context, answer = get_example(num_words, 6, 1, num_cw)
    else:
        context_example, answer_example = get_example(40, 10, 3, num_cw)
        context, answer = get_example(num_words, freq_cw, freq_ucw, num_cw)

    # One-shot example: the filled template immediately followed by its enumerated answer.
    input_example = TEMPLATE.format(context=context_example) + " ".join(
        [f"{i + 1}. {word}" for i, word in enumerate(answer_example)]
    )

    input_text = TEMPLATE.format(context=context)

    return input_example, input_text, answer


def sys_word_pair_random(
    num_samples: int,
    max_seq_length: int,
    tokenizer=None,
    incremental: int = 10,
    remove_newline_tab=False,
    tokens_to_generate=120,
):
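    """Generate `num_samples` CWE records whose prompts fit in `max_seq_length`.

    First calibrates how many distinct words fit in the token budget, then emits
    one record per sample containing the one-shot example, the trimmed prompt,
    the target words, and the "Answer: ..." generation prefix.
    """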
    assert tokenizer is not None, "Tokenizer is not provided."
    write_jsons = []

    # Calibrate num_words: grow the word list in `incremental` steps until a full
    # prompt (one-shot example + question + answer) would exceed the budget,
    # reserving `tokens_to_generate` tokens for the model's completion.
    num_words = incremental
    total_tokens = 0
    while total_tokens + tokens_to_generate < max_seq_length:
        input_example, input_text, answer = generate_input_output(
            num_words, max_seq_length
        )
        # Calculate the number of tokens in the example
        total_tokens = len(
            tokenizer(
                input_example
                + "\n"
                + input_text
                + " "
                + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
            ).input_ids
        )
        print(
            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
        )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_words -= incremental
            break

        num_words += incremental
        if num_words > len(WORDS):
            num_words = len(WORDS)
            break

    print("num_words:", num_words)

    # Generate samples
    for index in tqdm(range(num_samples)):
        used_words = num_words
        while True:
            try:
                input_example, input_text, answer = generate_input_output(
                    used_words, max_seq_length
                )
                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except AssertionError:
                # Prompt is still too long for this tokenizer; retry with fewer words.
                if used_words > incremental:
                    used_words -= incremental

        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
            input_example = " ".join(
                input_example.replace("\n", " ").replace("\t", " ").strip().split()
            )

        # Split off the trailing "Answer: ..." so it can serve as the generation
        # prefix. The template puts a newline (not a space) before "Answer", so
        # searching for " Answer" would never match.
        gen_prefix_index = input_text.rfind("Answer")
        gen_prefix = input_text[gen_prefix_index:].strip()
        input_text = input_text[:gen_prefix_index]
        formatted_output = {
            "index": index,
            "input": input_text,
            "input_example": input_example,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": gen_prefix,
        }
        write_jsons.append(formatted_output)

    return write_jsons


@cache
def get_tokenizer(pretrained):
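    """Load the HuggingFace tokenizer once and memoize it across sequence lengths."""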
    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


def get_dataset(pretrained, seq=None, **kwargs):
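    """Build 500 CWE samples sized for one target sequence length `seq`."""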
    tokenizer = get_tokenizer(pretrained)
    write_jsons = sys_word_pair_random(
        num_samples=500, max_seq_length=seq, tokenizer=tokenizer
    )
    return write_jsons


def get_cw_dataset(**kwargs):
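    """Task entry point: one `test` split covering every length in SEQ_LENGTHS."""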
    metadata = kwargs.get("metadata", {})
    pretrained = metadata.get("tokenizer", metadata.get("pretrained", {}))
    # One sample generator per target length, flattened into a single test split.
    per_length = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)

    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(per_length)), split=datasets.Split.TEST
        )
    }
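

# A minimal usage sketch (illustrative only; the tokenizer name below is an
# assumption, not something this module prescribes):
#
#     data = get_cw_dataset(metadata={"tokenizer": "gpt2"})
#     print(data["test"][0]["gen_prefix"])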