"router/src/infer/v3/queue.rs" did not exist on "5e6ddfd6a4fecc394255d7109f87c420c98b4e15"
text_generation_utils.py 22.5 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.training import communicate
Mohammad's avatar
Mohammad committed
30
from megatron.utils import get_ltor_masks_and_position_ids
31

32
33
34
35
36

def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
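
# Example usage (illustrative sketch; the tensors below are assumed, not
# defined in this file): top_k_logits filters a [batch, vocab] logits tensor
# in place, so that after a softmax the excluded tokens get ~zero probability:
#     logits = torch.randn(4, 50257).cuda()
#     logits = top_k_logits(logits, top_k=40, top_p=0.9)
#     probs = F.softmax(logits, dim=-1)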


def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to last stage
            # so it knows when to start overwriting
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_embedding_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                if mpu.is_pipeline_last_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_embedding_group()
                    context_length = input_info_tensor[2].item()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to last stage
            # so it knows when to start overwriting
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_embedding_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                if mpu.is_pipeline_last_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_embedding_group()
                    context_length = input_info_tensor[2].item()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
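
# Example (sketch; the values below are illustrative): pad_batch pads each
# context in place to args.seq_length and records the original lengths,
# e.g. with args.seq_length = 8 and pad_id = 0,
#     pad_batch([[5, 6, 7]], 0, args) -> ([[5, 6, 7, 0, 0, 0, 0, 0]], [3])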


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


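# Example usage (sketch, mirroring the callers above): get_token_stream yields
# the growing token batch after every generation step, so callers in this file
# drain the stream and keep the last item:
#     for tokens, lengths in get_token_stream(model, [context_tokens]):
#         pass
#     # `tokens` / `lengths` now hold the final sampled batch on the last
#     # pipeline stage; intermediate stages receive (None, None).

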
def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


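# Example (sketch; the tensors below are illustrative): switch() merges two
# tensors element-wise, taking val2 where `boolean` is 1 and val1 where it is 0:
#     switch(torch.tensor([7, 8]), torch.tensor([1, 2]), torch.tensor([0, 1]))
#     -> tensor([7, 2])
# sample_sequence_batch uses it to keep prompt tokens unchanged while
# sequences that have already consumed their prompt receive sampled tokens.

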
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
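    # Pipeline-parallel single-step forward: stages other than the first
    # receive their input activation from the previous stage, the last stage
    # returns the logits (plus the key/value cache when get_key_value is set),
    # and intermediate stages send their output onward and return None.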

    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
    else:
        input_tensor = None

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens, position_ids, attention_mask,
                                  tokentype_ids=tokentype_ids,
                                  layer_past=layer_past,
                                  get_key_value=get_key_value,
                                  forward_method_parallel_output=forward_method_parallel_output)
        else:
            output_tensor = model(tokens, position_ids, attention_mask,
                                  tokentype_ids=tokentype_ids,
                                  layer_past=layer_past,
                                  get_key_value=get_key_value)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask,
                              layer_past=layer_past,
                              get_key_value=get_key_value)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    if not mpu.is_pipeline_last_stage():
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        return None

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):
            if args.recompute:
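                # args.recompute: redo the forward pass over all tokens seen
                # so far at every step instead of reusing the key/value cache.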
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
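                # Incremental decoding with the key/value cache: the first
                # iteration feeds the whole prompt, later iterations feed only
                # the single most recently generated token.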
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                logits, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert logits is not None
                    logits = logits[:, -1].view(batch_size, -1).contiguous()

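            # On the last pipeline stage: sample the next token, write it into
            # the token matrix for sequences that have passed their prompt,
            # and broadcast both the new tokens and the global "done" flag so
            # the other stages stay in sync.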
            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break