cwe_utils.py
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import random

import datasets
import wonderwords
from tqdm import tqdm

from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer


RNG = random.Random(42)
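# Prompt for the RULER Common Words Extraction (CWE) task: the model reads a
# numbered word list and must name the ten words that appear most often.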
TEMPLATE = "Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?\n\nAnswer: The top 10 words that appear most often in the list are:"


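# Build the candidate word pool from wonderwords' noun/adjective/verb lists.
# Note that r._categories is a private attribute of RandomWord; sorting before
# the seeded shuffle keeps WORDS deterministic across runs.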
r = wonderwords.RandomWord()
WORDS = sorted(
    {item for x in ["noun", "adjective", "verb"] for item in r._categories[x]}
)
RNG.shuffle(WORDS)


def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10):
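    """Build one CWE context.

    Picks `num_words` distinct words, repeats the first `common_nums` of them
    `common_repeats` times and the rest `uncommon_repeats` times, shuffles the
    result, and renders it as "1. w1 2. w2 ...". Returns (context, common).
    """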
    # Use the module's seeded RNG (not the global `random`) for reproducibility.
    word_list_full = RNG.sample(WORDS, num_words)
    common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:]
    word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats)
    RNG.shuffle(word_list)

    # Formatting the word list as "1. word1 2. word2 3. word3 ..."
    context = " ".join([f"{i + 1}. {word}" for i, word in enumerate(word_list)])

    return context, common


def generate_input_output(
    num_words: int,
    max_seq_length: int,
    freq_cw: int = 30,
    freq_ucw: int = 3,
    num_cw: int = 10,
):
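    """Build (few-shot example, query prompt, answer words) for one sample.

    For short contexts (max_seq_length < 4096) both the few-shot example and
    the query use smaller repeat counts so everything fits in the window.
    """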
    if max_seq_length < 4096:
        context_example, answer_example = get_example(20, 3, 1, num_cw)
        context, answer = get_example(num_words, 6, 1, num_cw)
    else:
        context_example, answer_example = get_example(40, 10, 3, num_cw)
        context, answer = get_example(num_words, freq_cw, freq_ucw, num_cw)

    # TEMPLATE has no {query} placeholder, so only context is substituted.
    # A space separates the "... are:" prefix from the numbered answer,
    # mirroring the separator used when counting tokens below.
    input_example = (
        TEMPLATE.format(context=context_example)
        + " "
        + " ".join(f"{i + 1}. {word}" for i, word in enumerate(answer_example))
    )

    input_text = TEMPLATE.format(context=context)

    return input_example, input_text, answer


def sys_word_pair_random(
    num_samples: int,
    max_seq_length: int,
    tokenizer=None,
    incremental: int = 10,
    remove_newline_tab=False,
    tokens_to_generate=120,
):
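    """Generate `num_samples` CWE records that fit within `max_seq_length`.

    First grows the word count in steps of `incremental` until the tokenized
    prompt would exceed the token budget, then emits samples at that size,
    backing off by `incremental` whenever a sample still comes out too long.
    """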
    assert tokenizer is not None, "Tokenizer is not provided."
    write_jsons = []

    # Search for the largest num_words whose prompt fits the token budget.
    num_words = incremental

    total_tokens = 0
    while total_tokens + tokens_to_generate < max_seq_length:
        input_example, input_text, answer = generate_input_output(
            num_words, max_seq_length
        )
        # Calculate the number of tokens in the example
        total_tokens = len(
            tokenizer(
                input_example
                + "\n"
                + input_text
                + " "
                + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
            ).input_ids
        )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_words -= incremental
            break

        num_words += incremental
        if num_words > len(WORDS):
            num_words = len(WORDS)
            break


    # Generate samples
    for index in tqdm(
        range(num_samples), desc=f"Generating CWE Samples | {max_seq_length}"
    ):
        used_words = num_words
        while True:
            try:
                input_example, input_text, answer = generate_input_output(
                    used_words, max_seq_length
                )
                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except AssertionError:  # sample too long; retry with fewer words
                if used_words > incremental:
                    used_words -= incremental

        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
            input_example = " ".join(
                input_example.replace("\n", " ").replace("\t", " ").strip().split()
            )

        # Split the generation prefix ("Answer: The top 10 ...") out of the
        # prompt. Search for "Answer" itself: the old " Answer" pattern only
        # matched after newlines had been collapsed to spaces.
        gen_prefix_index = input_text.rfind("Answer")
        gen_prefix = input_text[gen_prefix_index:].strip()
        input_text = input_text[:gen_prefix_index]
        formatted_output = {
            "index": index,
            "input": input_text,
            "input_example": input_example,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": gen_prefix,
        }
        write_jsons.append(formatted_output)

    return write_jsons


def get_dataset(pretrained, seq=None, **kwargs):
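    """Generate 500 CWE samples for a single sequence length `seq`."""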
    tokenizer = get_tokenizer(pretrained)
    write_jsons = sys_word_pair_random(
        num_samples=500, max_seq_length=seq, tokenizer=tokenizer
    )
    return write_jsons


def get_cw_dataset(**kwargs):
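    """Build the dataset dict with a single "test" split covering every
    requested sequence length (kwargs: tokenizer/pretrained, max_seq_lengths).
    """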
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    df = (
        get_dataset(pretrained, seq=seq)
        for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )

    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
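

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: "gpt2" is a
    # hypothetical tokenizer id here; pass whatever get_tokenizer accepts in
    # your setup. Each sequence length generates 500 samples, so this is slow.
    ds = get_cw_dataset(pretrained="gpt2", max_seq_lengths=[4096])
    sample = ds["test"][0]
    print(sample["gen_prefix"], "->", sample["outputs"])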