text_generation_utils.py 22.9 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
31

32
33
34
35
36
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

37
38
39
40
def get_batch(context_tokens):
    """Generate batch from context tokens.

    Args:
        context_tokens: LongTensor of token ids. A 2-D ``[batch, seq]`` tensor
            keeps its own batch size; any other shape is reshaped to
            ``[args.micro_batch_size, -1]``.

    Returns:
        ``(tokens, attention_mask, position_ids)`` where ``tokens`` has been
        moved to the GPU and the mask / position ids come from
        ``get_ltor_masks_and_position_ids`` (its second return value is
        discarded).
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # get_token_stream passes a [num_contexts, seq_length] tensor. Forcing it
    # into micro_batch_size rows would scramble it whenever the number of
    # contexts differs (e.g. one prompt with micro_batch_size > 1).
    if context_tokens.dim() == 2:
        batch_size = context_tokens.size(0)
    else:
        batch_size = args.micro_batch_size

    # Move to GPU.
    tokens = context_tokens.view(batch_size, -1).contiguous().cuda()

    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids

54

55
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter logits for top-k and/or nucleus (top-p) sampling.

    Adapted from the Hugging Face conversational AI code:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor of scores whose last dimension indexes the vocabulary.
            It is modified in place.
        top_k: if > 0, keep only the ``top_k`` highest scores along the last
            dimension (clamped to the size of that dimension).
        top_p: if > 0.0, keep the highest-probability entries up to and
            including the first one at which the cumulative softmax
            probability exceeds ``top_p`` (at least one entry is always kept).
        filter_value: value written into every filtered-out position.

    Returns:
        ``logits`` itself, with filtered positions set to ``filter_value``.
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens with a score below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        # Sort scores in descending order along the vocabulary dimension.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is
        # kept as well.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits


88
def generate_samples_input_from_file(model):
    """Generate a continuation for every line of ``args.sample_input_file``.

    Only the tensor-parallel rank 0 of the first pipeline stage reads the
    input file. That rank also prints and writes each context and its
    generated continuation to ``args.sample_output_file`` (or
    ``<sample_input_file>.out`` when unset). Every other rank still executes
    the same sequence of collective calls so that all ranks stay in step.

    Lines are processed in order, including the last one. A line containing
    the substring "stop" ends the run early. A line whose token count is at
    least half of ``args.seq_length`` is reported and skipped. The run ends
    once every line has been processed. The output file is closed on return.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'

    is_reader = mpu.is_pipeline_first_stage() and \
        mpu.get_tensor_model_parallel_rank() == 0

    fname_out = None
    if is_reader:
        with open(args.sample_input_file, "r") as fname:
            all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0

        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    model.eval()
    try:
        with torch.no_grad():
            while True:
                terminate_runs = 0
                raw_text_len = 0

                if is_reader:
                    if input_pos >= input_count:
                        # Every line has been consumed (or the file is empty).
                        terminate_runs = 1
                        context_length = 0
                    else:
                        raw_text = all_raw_text[input_pos]
                        input_pos += 1
                        raw_text_len = len(raw_text)

                        if "stop" in raw_text:
                            terminate_runs = 1
                            # Must be bound: it is all-reduced below.
                            context_length = 0
                        else:
                            context_tokens = tokenizer.tokenize(raw_text)
                            context_length = len(context_tokens)

                            if context_length >= (args.seq_length // 2):
                                print("\nContext length", context_length,
                                      "\nPlease give smaller context (half of the "
                                      "sequence length)!", flush=True)
                                continue
                else:
                    context_tokens = tokenizer.tokenize("EMPTY TEXT")
                    context_length = 0

                # Share the reader's flags/lengths with every model-parallel
                # rank (all non-reader ranks contribute zeros).
                input_info = [terminate_runs, raw_text_len, context_length]
                input_info_tensor = torch.cuda.LongTensor(input_info)
                torch.distributed.all_reduce(input_info_tensor,
                                             group=mpu.get_model_parallel_group())
                terminate_runs = input_info_tensor[0].item()
                raw_text_len = input_info_tensor[1].item()
                context_length = input_info_tensor[2].item()

                if terminate_runs == 1:
                    return

                # For pipeline parallel we send context tokens to other stages
                # so they get the lengths correct
                if mpu.get_tensor_model_parallel_rank() == 0 \
                   and args.pipeline_model_parallel_size > 1:
                    if mpu.is_pipeline_first_stage():
                        src = mpu.get_pipeline_model_parallel_first_rank()
                        group = mpu.get_pipeline_model_parallel_group()
                        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                        torch.distributed.broadcast(context_tokens_tensor, src, group)
                    else:
                        src = mpu.get_pipeline_model_parallel_first_rank()
                        group = mpu.get_pipeline_model_parallel_group()
                        context_tokens_tensor = torch.empty(context_length,
                                                            dtype=torch.int64,
                                                            device=torch.device("cuda"))
                        torch.distributed.broadcast(context_tokens_tensor, src, group)
                        context_tokens = context_tokens_tensor.cpu().numpy().tolist()

                # Drain the stream on every rank; only the final step is used.
                token_stream = get_token_stream(model, [context_tokens])
                for decode_tokens in token_stream:
                    pass

                if is_reader:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")
    finally:
        if fname_out is not None:
            fname_out.close()
Mohammad's avatar
Mohammad committed
192

193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sample):
    """Generate a continuation of ``context`` for language-model evaluation.

    Overwrites global generation arguments (``out_seq_length``, ``recompute``,
    ``eos_id``, ``greedy`` and, when sampling, ``top_p``, ``temperature`` and
    ``top_k``). These settings persist after the call.

    Args:
        model: the language model.
        context: prompt text.
        max_gen_length: maximum number of new tokens to generate.
        eos_token_id: token id that finishes a sequence early.
        do_sample: sample (top-k 50, temperature 1.0) when True, otherwise
            decode greedily.

    Returns:
        The detokenized prompt + continuation with the first ``len(context)``
        characters removed. This assumes detokenizing the tokenized prompt
        reproduces ``context``. Returns ``None`` on ranks whose token stream
        yields no tokens.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    # get_token_stream pads context_tokens in place, so record its length now.
    num_context_tokens = len(context_tokens)

    # sample_sequence_batch caps generation at
    # min(seq_length - 1, context_length + out_seq_length), so out_seq_length
    # counts new tokens only.
    args.out_seq_length = max_gen_length
    args.recompute = True  # set this default value
    args.eos_id = eos_token_id

    if not do_sample:
        args.greedy = True
    else:
        # Reset explicitly: an earlier greedy call leaves args.greedy True.
        args.greedy = False
        # Settings similar to Hugging Face.
        args.top_p = 1.0
        args.temperature = 1.0
        args.top_k = 50

    last_tokens = None
    with torch.no_grad():
        # Exhaust the stream so every rank completes all collective steps.
        for tokens, _ in get_token_stream(model, [context_tokens]):
            if tokens is not None:
                last_tokens = tokens

    if last_tokens is None:
        return None

    # The sampler fills positions up to and including its maxlen, i.e. one
    # token past max_gen_length; keep exactly max_gen_length new tokens.
    decode_tokens = last_tokens[0].cpu().numpy().tolist()
    decode_tokens = decode_tokens[:num_context_tokens + max_gen_length]
    return tokenizer.detokenize(decode_tokens)[raw_text_len:]

228

229
def generate_samples_interactive(model, print_frequency=24):
    """Read prompts from stdin and print the model's continuations.

    The tensor-parallel rank 0 of the first pipeline stage prompts the user.
    All other ranks join the same collective calls so every rank stays in
    step.

    While generating, the partial continuation is reprinted every
    ``print_frequency`` steps. The final continuation is printed once
    generation ends. A prompt containing "stop" ends the session. Prompts of
    at least half of ``args.seq_length`` tokens are rejected.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                    # context_length is all-reduced below, so it must be
                    # bound even when "stop" is the very first prompt.
                    context_length = 0
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            # Share the prompting rank's flags/lengths with every
            # model-parallel rank (all other ranks contribute zeros).
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                # Only the prompting rank prints, and only every
                # print_frequency steps; all ranks keep iterating.
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                # The loop above already converted decode_tokens to a list
                # when its last step was a printing step.
                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")
Mohammad's avatar
Mohammad committed
327

328

329
330

def generate_samples_unconditional(model):
    """Yield ``args.num_samples`` unconditionally generated samples.

    Each batch starts ``args.micro_batch_size`` sequences from the
    end-of-document token. On the tensor-parallel rank 0 of the last pipeline
    stage, each sample is yielded as a dict with keys ``'text'``, ``'length'``
    and ``'finished'``. All other ranks yield ``None`` per sample, so every
    rank iterates the same number of times.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    # Timing window for the 'Avg s/batch' log line: when it started and how
    # many batches it covers.
    window_start = time.time()
    window_batches = 0
    while True:
        # get_token_stream pads its input in place, so pass a fresh copy.
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        window_batches += 1

        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - window_start) / window_batches)
                window_start = time.time()
                window_batches = 0
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break

372

Mohammad's avatar
Mohammad committed
373
def generate_and_write_samples_unconditional(model):
    """Write unconditional samples to ``args.genfile``, one JSON object per line.

    Only the tensor-parallel rank 0 of the last pipeline stage (the rank that
    receives real samples) opens and writes the file. Opening it with 'w' on
    every rank would let a late rank truncate what the writer already wrote.
    Other ranks just drain the generator so they take part in every
    generation step.
    """
    args = get_args()
    assert args.genfile is not None

    samples = generate_samples_unconditional(model)
    is_writer = mpu.is_pipeline_last_stage() and \
        mpu.get_tensor_model_parallel_rank() == 0

    if not is_writer:
        for _ in samples:
            pass
        return

    with open(args.genfile, 'w') as f:
        for datum in samples:
            f.write(json.dumps(datum) + '\n')
382

383

Mohammad's avatar
Mohammad committed
384
385
def pad_batch(batch, pad_id, args):
    """Right-pad each sequence of ``batch`` with ``pad_id`` up to ``args.seq_length``.

    Padding is done in place with ``extend``. Sequences that are already at
    least ``args.seq_length`` long are left untouched (not truncated).

    Returns:
        ``(batch, context_lengths)``: the same ``batch`` object and the length
        each sequence had before padding.
    """
    context_lengths = []
    for seq in batch:
        seq_len = len(seq)
        shortfall = args.seq_length - seq_len
        if shortfall > 0:
            seq.extend([pad_id] * shortfall)
        context_lengths.append(seq_len)
    return batch, context_lengths

394
395

def get_token_stream(model, context_tokens):
    """Yield the growing token batch as generation proceeds.

    ``context_tokens`` (a list of token-id lists) is padded in place to
    ``args.seq_length`` with the tokenizer's ``eod`` id. The tokens and their
    lengths are broadcast from the tensor-parallel source rank and handed to
    ``sample_sequence_batch``.

    Each step yields ``(tokens[:, :n], lengths)``, where ``n`` starts one past
    the shortest context and grows by one per step. It yields
    ``(None, None)`` when the sampler produces no tokens on this rank.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    padded_tokens, raw_lengths = pad_batch(context_tokens,
                                           tokenizer.eod, args)

    tokens_tensor = torch.cuda.LongTensor(padded_tokens)
    lengths_tensor = torch.cuda.LongTensor(raw_lengths)

    # Lengths first, then tokens, from the tensor-parallel source rank.
    for payload in (lengths_tensor, tokens_tensor):
        torch.distributed.broadcast(payload,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    visible_len = lengths_tensor.min().item()
    _, attention_mask, position_ids = get_batch(tokens_tensor)

    sampler = sample_sequence_batch(model, tokens_tensor, lengths_tensor,
                                    attention_mask, position_ids)
    for step_tokens, step_lengths in sampler:
        visible_len += 1
        if step_tokens is None:
            yield None, None
        else:
            yield step_tokens[:, :visible_len], step_lengths
425
426
427


def switch(val1, val2, boolean):
    """Element-wise select ``val2`` where ``boolean`` is 1 and ``val1`` where it is 0.

    ``boolean`` is cast to ``val1``'s type and used as an arithmetic blend
    weight: ``(1 - w) * val1 + w * val2``.
    """
    weight = boolean.type_as(val1)
    keep_weight = 1 - weight
    return keep_weight * val1 + weight * val2
431

432

433
434
435
436
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run this pipeline stage's forward pass for one generation step.

    The step receives the previous stage's output with ``recv_forward``,
    feeds it to the unwrapped model through ``set_input_tensor``, runs
    ``model`` and passes the result on with ``send_forward``.

    ``args.seq_length`` is temporarily set to ``tokens.shape[1]``. The hidden
    size changes when not recomputing, so the p2p communication functions
    must be told the current size. It is restored in ``finally``, so an
    exception cannot leave the global argument modified.

    Returns:
        ``(output_tensor, layer_past)`` when ``get_key_value`` is truthy,
        otherwise ``output_tensor``.
    """
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(tokens, position_ids, attention_mask,
                              tokentype_ids=tokentype_ids,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)

        if get_key_value:
            output_tensor, layer_past = output_tensor

        send_forward(output_tensor)
    finally:
        args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


466
467
468
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Generator that autoregressively fills ``context_tokens`` in place.

    Each iteration runs one forward step and writes one new position. Rows
    whose context is longer than the current position keep their own token
    there. What is yielded depends on the pipeline stage:

    * last stage: ``(tokens, lengths)``, where ``lengths[i]`` is the position
      at which row ``i`` produced ``eos_id`` (``maxlen`` if it has not);
    * first stage (when not also last): ``(tokens, None)`` after receiving the
      new tokens from the last stage;
    * any other stage: ``(None, None)``.

    Generation stops after position ``maxlen`` or once every row has produced
    ``eos_id`` (``args.eos_id`` if set, else the tokenizer's ``eod``).

    Args:
        model: the language model.
        context_tokens: ``[batch, seq]`` LongTensor of padded context ids;
            modified in place.
        context_lengths: tensor of per-row context lengths.
        attention_mask: attention mask passed to every forward step.
        position_ids: position ids, sliced alongside the tokens.
        maxlen: last position to fill; defaults to
            ``min(args.seq_length - 1, min(context_lengths) + args.out_seq_length)``.
        type_ids: optional token-type ids, sliced alongside the tokens.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        cur_len = context_lengths.min().item()

        eos_id = args.eos_id if hasattr(args, 'eos_id') else tokenizer.eod

        counter = 0
        org_context_length = cur_len

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = min(args.seq_length - 1,
                         org_context_length + args.out_seq_length)

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while cur_len <= maxlen:
            if args.recompute:
                # No key/value cache: run the whole sequence every step.
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, cur_len - 1, :]
            else:
                # With the cache, the first step feeds every position before
                # cur_len; later steps feed only the most recent position.
                if counter == 0:
                    step_tokens = tokens[:, :cur_len]
                    step_positions = position_ids[:, :cur_len]
                    step_types = (None if type_ids is None
                                  else type_ids[:, :cur_len])
                else:
                    last_pos = cur_len - 1
                    step_tokens = tokens[:, last_pos].view(batch_size, -1)
                    step_positions = position_ids[:, last_pos].view(
                        batch_size, -1)
                    step_types = (None if type_ids is None
                                  else type_ids[:, last_pos].view(batch_size, -1))
                output, layer_past = forward_step(
                    model, step_tokens,
                    step_positions,
                    attention_mask,
                    layer_past=layer_past,
                    get_key_value=True,
                    tokentype_ids=step_types,
                    forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(probs, num_samples=1).view(-1)

                # Only rows whose context has ended take the sampled token.
                started = context_lengths <= cur_len

                new_tokens = switch(
                    tokens[:, cur_len].view(-1), prev, started)
                tokens[:, cur_len] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = cur_len
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, cur_len])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, cur_len] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the last stage's "all rows done" flag.
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            cur_len += 1
            counter += 1

            if done:
                break