"sgl-kernel/include/utils.h" did not exist on "b3251e9f40b85159d52563b9ca8276fa0fa03703"
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        args.fp16)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p)
    filtering. This function has been mostly taken from the huggingface
    conversational AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """

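    # Illustrative note (added for clarity; not part of the original code):
    # with top_k=2 only the two largest logits in each row survive, e.g.
    #     top_k_logits(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), top_k=2)
    #     # -> tensor([[-inf, 3.0, 2.0, -inf]])
    # With top_p, tokens are kept in decreasing-probability order until the
    # cumulative softmax probability exceeds top_p; the token that crosses
    # the threshold is still kept.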
    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits in descending order and compute cumulative probabilities.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
    """Generate text for prompts read from a file and save the outputs."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.get_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give a shorter context (less than "
                              "half of the sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                fname_out.write("\nContext:")
                fname_out.write(raw_text)
                fname_out.write("\n\nMegatron-LM:")
                fname_out.write(trim_decode_tokens)
                fname_out.write("\n")

            raw_text = None

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):
    """Generate text interactively from prompts typed at the console."""
    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give a shorter context (less than "
                              "half of the sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = len(context_tokens)

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

                if mpu.get_model_parallel_rank() == 0 and \
                   counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[len(raw_text):]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            context_count += 1

            if mpu.get_model_parallel_rank() == 0:
                input("\nPress Enter to continue >>>")


def generate_samples_unconditional(model):
    """Generate unconditional samples, prompting only with the end-of-document token."""
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if ctr % args.log_interval == 0:
            print('Avg s/batch:',
                  (time.time() - start_time) / min(args.log_interval, ctr + 1))
            start_time = time.time()
        length = len(token_stream)
        token_batch = token_stream[0].cpu().numpy().tolist()
        length_batch = token_stream[1].cpu().numpy().tolist()
        for tokens, length in zip(token_batch, length_batch):
            tokens = tokens[1:length-1]
            text = tokenizer.detokenize(tokens)
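            # Explanatory note (added): each yielded datum carries the
            # detokenized text, the generated length, and whether generation
            # stopped at an end-of-document token before hitting the
            # sequence-length limit.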
            is_finished = length < args.seq_length - 1
            datum = {'text': text, 'length': length-1, 'finished': is_finished}
            yield datum
            ctr += 1
            if ctr >= num_samples:
                break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):
    """Write unconditional samples to args.genfile as JSON lines."""
    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum)+'\n')


def pad_batch(batch, pad_id, args):

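    # Explanatory note (added): each context is right-padded with `pad_id` up
    # to args.seq_length and its original (unpadded) length is recorded.
    # A hedged illustration, assuming args.seq_length == 8 and pad_id == 0:
    #     pad_batch([[5, 6, 7]], 0, args)
    #     # -> ([[5, 6, 7, 0, 0, 0, 0, 0]], [3])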
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id]*(args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
    """Yield (tokens, lengths) for a batch of prompts as generation proceeds."""
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

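    # Explanatory note (added): the model-parallel source rank owns the
    # prompt; broadcasting the padded tokens and their lengths keeps every
    # rank in the group working on identical inputs.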
    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        yield tokens[:, :context_length], lengths


def switch(val1, val2, boolean):

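    # Explanatory note (added): element-wise select that returns val2 where
    # `boolean` is set and val1 elsewhere. Illustrative example:
    #     switch(torch.tensor([10, 20]), torch.tensor([7, 8]), torch.tensor([0, 1]))
    #     # -> tensor([10, 8])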
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Sample a batch one token at a time, yielding (tokens, lengths) each step."""
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length
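            # Explanatory note (added): generation is capped at out_seq_length
            # new tokens beyond the shortest prompt and never exceeds
            # args.seq_length - 1 positions in total.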

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):

            if args.recompute:
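                # Explanatory note (added): recompute mode re-runs the full
                # forward pass over all tokens generated so far and reads off
                # the logits at the current position; simpler but slower than
                # reusing cached key/value states in the branch below.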
                logits = model(tokens,
                               position_ids,
                               attention_mask,
                               tokentype_ids=type_ids,
                               forward_method_parallel_output=False)
                logits = logits[:, context_length - 1, :]
            else:
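                # Explanatory note (added): incremental mode processes the
                # whole prompt on the first step, then feeds only the newest
                # token and reuses cached key/value states via `layer_past`
                # (get_key_value=True).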
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = model(tokens2use,
                                           positions2use,
                                           attention_mask,
                                           layer_past=layer_past,
                                           get_key_value=True,
                                           tokentype_ids=types2use,
                                           forward_method_parallel_output=False)
                logits = logits[:, -1].view(batch_size, -1).contiguous()

            if args.greedy:
                prev = torch.argmax(logits, dim=-1).view(-1)
            else:
                logits = logits.float()
                logits /= args.temperature
                logits = top_k_logits(logits, top_k=args.top_k,
                                      top_p=args.top_p)
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            print_logits = []
            for p in prev:
                print_logits.append([logits[i, p].item()
                                     for i in range(batch_size)])
            started = context_lengths <= context_length
            tokens[:, context_length] = switch(
                tokens[:, context_length].view(-1), prev, started)
            context_length += 1
            counter += 1

            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            is_done = is_done | done_token
            done = torch.all(is_done)

            yield tokens, lengths
            if done:
                break