# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import random

import datasets
import wonderwords
from tqdm import tqdm

from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer


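# RULER "Common Words Extraction" (CWE): the model reads a long numbered word
# list and must name the 10 words that appear most often.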
CONFIG = {
    "tokens_to_generate": 120,
    "template": """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""",
    "answer_prefix": """ Answer: The top 10 words that appear most often in the list are:""",
}

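# Seeded RNG so word sampling and shuffling are reproducible across runs.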
RNG = random.Random(42)
TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]


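# Candidate vocabulary: the union of wonderwords' noun, adjective, and verb
# lists. Note this reads RandomWord's private ``_categories`` attribute.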
r = wonderwords.RandomWord()
WORDS = sorted(
    {
        word
        for category in ("noun", "adjective", "verb")
        for word in r._categories[category]
    }
)
RNG.shuffle(WORDS)


def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10):
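    """Build one CWE word list.

    ``common_nums`` target words are repeated ``common_repeats`` times and
    mixed with ``num_words - common_nums`` distractor words repeated
    ``uncommon_repeats`` times. Returns the numbered context string and the
    list of target (common) words.
    """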
    word_list_full = RNG.sample(WORDS, num_words)  # seeded RNG, not the global one
    common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:]
    word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats)
    RNG.shuffle(word_list)

    # Formatting the word list as "1. word1 2. word2 3. word3 ..."
    context = " ".join([f"{i + 1}. {word}" for i, word in enumerate(word_list)])

    return context, common


def generate_input_output(
    num_words: int,
    max_seq_length: int,
    freq_cw: int = 30,
    freq_ucw: int = 3,
    num_cw: int = 10,
):
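    """Render the in-context example and the query prompt for one sample.

    Short-context settings (under 4096 tokens) use smaller word lists and
    repeat counts so the prompt still fits. Returns a tuple of
    (input_example, input_text, answer).
    """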
    if max_seq_length < 4096:
        context_example, answer_example = get_example(20, 3, 1, num_cw)
        context, answer = get_example(num_words, 6, 1, num_cw)
    else:
        context_example, answer_example = get_example(40, 10, 3, num_cw)
        context, answer = get_example(num_words, freq_cw, freq_ucw, num_cw)

    # TEMPLATE only has a {context} placeholder; the in-context example also
    # carries its gold answer inline.
    input_example = TEMPLATE.format(context=context_example) + " ".join(
        f"{i + 1}. {word}" for i, word in enumerate(answer_example)
    )

    input_text = TEMPLATE.format(context=context)

    return input_example, input_text, answer


def sys_word_pair_random(
    num_samples: int,
    max_seq_length: int,
    tokenizer=None,
    incremental: int = 10,
    remove_newline_tab=False,
    tokens_to_generate=120,
):
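    """Generate ``num_samples`` CWE samples that fit within ``max_seq_length``.

    The word count is first calibrated in steps of ``incremental``; each
    sample is then re-checked against the token budget and shrunk if needed.
    """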
    assert tokenizer is not None, "Tokenizer is not provided."
    write_jsons = []

    # Calibrate num_words: grow it in steps of `incremental` until the rendered
    # prompt plus the generation budget would exceed max_seq_length.
    num_words = incremental

    total_tokens = 0
    while total_tokens + tokens_to_generate < max_seq_length:
        input_example, input_text, answer = generate_input_output(
            num_words, max_seq_length
        )
        # Count tokens for the full rendered sample: example, query, and gold answer
        total_tokens = len(
            tokenizer(
                input_example
                + "\n"
                + input_text
                + " "
                + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
            ).input_ids
        )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_words -= incremental
            break

        num_words += incremental
        if num_words > len(WORDS):
            num_words = len(WORDS)
            break


    # Generate samples
    for index in tqdm(
        range(num_samples), desc=f"Generating CWE Samples | {max_seq_length}"
    ):
        used_words = num_words
        while True:
            try:
                input_example, input_text, answer = generate_input_output(
                    used_words, max_seq_length
                )
                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except AssertionError:
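                # Sample came out too long: back off by one increment and
                # retry; at or below the step size, just resample.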
                if used_words > incremental:
                    used_words -= incremental

        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
            input_example = " ".join(
                input_example.replace("\n", " ").replace("\t", " ").strip().split()
            )

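        # Drop the answer prefix from the input; it is stored separately under
        # "gen_prefix" so it can be re-appended at generation time.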
        gen_prefix_index = input_text.rfind(CONFIG["answer_prefix"])
        input_text = input_text[:gen_prefix_index]
        formatted_output = {
            "index": index,
            "input": input_text.strip(),
            "input_example": input_example,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": CONFIG["answer_prefix"].strip(),
        }
        write_jsons.append(formatted_output)

    return write_jsons


def get_dataset(pretrained, seq=None, **kwargs):
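    """Build the CWE samples for a single target sequence length."""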
    tokenizer = get_tokenizer(pretrained)
    write_jsons = sys_word_pair_random(
        num_samples=500, max_seq_length=seq, tokenizer=tokenizer
    )
    return write_jsons


def get_cw_dataset(**kwargs):
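    """Build the CWE test split covering every requested sequence length."""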
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    sample_lists = (
        get_dataset(pretrained, seq=seq)
        for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )

    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(sample_lists)),
            split=datasets.Split.TEST,
        )
    }
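

# Example usage (a sketch, not part of the task pipeline; "gpt2" is an assumed
# tokenizer name that get_tokenizer is expected to resolve):
#
#     data = get_cw_dataset(tokenizer="gpt2", max_seq_lengths=[4096])
#     print(data["test"][0]["input"][:300])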