# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.training import communicate
from megatron.utils import get_ltor_masks_and_position_ids


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313"""

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits and compute cumulative probabilities.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):
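    """Generate a continuation for each prompt in `args.sample_input_file`
    and write the results to `args.sample_output_file` (defaults to
    `<sample_input_file>.out`)."""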

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
129
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

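            # Share the prompt metadata (terminate flag, raw text length,
            # context length) with all model-parallel ranks.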
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):
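    """Read prompts interactively on the first rank and stream the generated
    continuation, refreshing the printed output every `print_frequency`
    tokens."""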

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
220
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

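            # Share the prompt metadata (terminate flag, raw text length,
            # context length) with all model-parallel ranks.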
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_unconditional(model):
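    """Yield `args.num_samples` unconditionally generated samples, seeding
    each micro-batch with a single end-of-document token."""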

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):
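    """Write unconditional samples to `args.genfile`, one JSON object
    per line."""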

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):
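    """Pad each context up to `args.seq_length` with `pad_id` and return
    the padded batch together with the original context lengths."""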

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_token_stream(model, context_tokens):
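    """Pad and broadcast the contexts to all tensor-parallel ranks, then
    yield the generated tokens (trimmed to the current length) after each
    step."""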

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):
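    """Element-wise select: return val2 where `boolean` is set and val1
    elsewhere."""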

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
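    """Run a single forward pass, receiving input activations from the
    previous pipeline stage and sending outputs to the next stage when
    pipeline parallelism is used."""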

    # Hidden size changes when not using recompute, need to tell communicate()
    # the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

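    # Receive the input activations from the previous pipeline stage;
    # the first stage consumes the tokens directly.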
    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
    else:
        input_tensor = None

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens, position_ids, attention_mask,
                                  tokentype_ids=tokentype_ids,
                                  layer_past=layer_past,
                                  get_key_value=get_key_value,
                                  forward_method_parallel_output=forward_method_parallel_output)
        else:
            output_tensor = model(tokens, position_ids, attention_mask,
                                  tokentype_ids=tokentype_ids,
                                  layer_past=layer_past,
                                  get_key_value=get_key_value)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask,
                              layer_past=layer_past,
                              get_key_value=get_key_value)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    if not mpu.is_pipeline_last_stage():
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
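    """Sample one token per step for the whole batch and yield the updated
    (tokens, lengths) after every step; ranks that do not hold the data
    yield None instead."""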

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):
            if args.recompute:
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
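                # Broadcast the freshly sampled tokens to the first stage
                # (embedding group) so its next forward pass can use them.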
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
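                # On stages other than the last, the first stage receives the
                # newly sampled tokens from the last stage via the embedding
                # group so later forward passes stay in sync.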
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break