# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

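"""Utilities for building "common words extraction" (CWE) style samples.

Each sample repeats a handful of "common" words many times inside a long
numbered list of rarely repeated "uncommon" distractors; the expected answer
is the list of common words. Prompt sizes are grown until they fill the
requested sequence length.
"""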

import random

import wonderwords
from tqdm import tqdm


RNG = random.Random(42)
# Prompt template with {context} and {query} placeholders; left empty here and
# expected to be filled in by the caller before samples are generated.
TEMPLATE = ""
_word_gen = wonderwords.RandomWord()
# Deterministic word pool: the union of wonderwords' noun/adjective/verb lists,
# sorted before the seeded shuffle so runs are reproducible.
WORDS = sorted(
    {
        word
        for category in ("noun", "adjective", "verb")
        for word in _word_gen._categories[category]
    }
)
RNG.shuffle(WORDS)


def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10):
    word_list_full = RNG.sample(WORDS, num_words)  # use the seeded RNG, not the global one
    common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:]
    word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats)
    RNG.shuffle(word_list)

    # Format the word list as "1. word1 2. word2 3. word3 ..."
    context = " ".join([f"{i + 1}. {word}" for i, word in enumerate(word_list)])

    return context, common

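# Illustration with hypothetical words: get_example(4, common_repeats=2,
# uncommon_repeats=1, common_nums=2) draws 4 words, repeats the first 2 twice
# and the rest once, and may shuffle to
# "1. alpha 2. beta 3. gamma 4. alpha 5. delta 6. beta", returning that string
# together with ["alpha", "beta"] as the answer.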

def generate_input_output(
    num_words: int,
    max_seq_length: int,
    freq_cw: int = 30,
    freq_ucw: int = 3,
    num_cw: int = 10,
):
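    """Render a few-shot demo and a query prompt; return (demo, prompt, answer)."""
    # For small sequence budgets, fall back to smaller fixed demo/context
    # settings so everything fits; otherwise honor the caller's frequency knobs.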
    if max_seq_length < 4096:
        context_example, answer_example = get_example(20, 3, 1, num_cw)
        context, answer = get_example(num_words, 6, 1, num_cw)
    else:
        context_example, answer_example = get_example(40, 10, 3, num_cw)
        context, answer = get_example(num_words, freq_cw, freq_ucw, num_cw)

    template = TEMPLATE

    input_example = template.format(
        context=context_example,
        query="",
    ) + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer_example)])

    input_text = template.format(
        context=context,
        query="",
    )

    return input_example, input_text, answer


def sys_word_pair_random(
    num_samples: int,
    max_seq_length: int,
    TOKENIZER=None,
    incremental: int = 10,
    remove_newline_tab=False,
    tokens_to_generate=120,
):
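    """Build num_samples CWE samples whose prompts fit within max_seq_length tokens."""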
    assert TOKENIZER is not None, "Tokenizer is not provided."
    write_jsons = []

    # Grow num_words until the rendered sample saturates the sequence budget
    num_words = incremental

    total_tokens = 0
    while total_tokens + tokens_to_generate < max_seq_length:
        input_example, input_text, answer = generate_input_output(
            num_words, max_seq_length
        )
        # Count tokens of the fully rendered sample: demo example, prompt, and answer
        total_tokens = len(
            TOKENIZER.text_to_tokens(
                input_example
                + "\n"
                + input_text
                + " "
                + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
            )
        )
        print(
            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
        )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_words -= incremental
            break

        num_words += incremental
        if num_words > len(WORDS):
            num_words = len(WORDS)
            break

    print("num_words:", num_words)

    # Generate samples
    for index in tqdm(range(num_samples)):
        used_words = num_words
        while True:
            try:
                input_example, input_text, answer = generate_input_output(
                    used_words, max_seq_length
                )
                length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except Exception:  # tokenization or length check failed; retry smaller
                if used_words > incremental:
                    used_words -= incremental
                else:
                    # Cannot shrink further; re-raise instead of retrying forever.
                    raise

        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
            input_example = " ".join(
                input_example.replace("\n", " ").replace("\t", " ").strip().split()
            )

        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
        }
        write_jsons.append(formatted_output)

    return write_jsons
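

# Minimal smoke test, assuming nothing beyond this module. _WhitespaceTokenizer
# is a hypothetical stand-in: real callers are expected to pass a tokenizer
# that exposes text_to_tokens() and to set TEMPLATE to a prompt containing
# {context} and {query} placeholders before generating samples.
if __name__ == "__main__":

    class _WhitespaceTokenizer:
        def text_to_tokens(self, text):
            return text.split()

    samples = sys_word_pair_random(
        num_samples=1,
        max_seq_length=1024,
        TOKENIZER=_WhitespaceTokenizer(),
    )
    print(samples[0]["length"], samples[0]["outputs"])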