text_generation_utils.py 21.8 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
31

32
33
34
35
36
# These are needed to unwrap the model; ideally they would live in megatron.utils.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

37
38
39
40
def get_batch(context_tokens):
    """Build the model inputs for a batch of context tokens.

    The tokens are reshaped to ``args.micro_batch_size`` rows and moved to
    the GPU; the left-to-right attention mask and the position ids are then
    derived from them.

    Returns:
        A ``(tokens, attention_mask, position_ids)`` tuple.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # One row per micro-batch sample, made contiguous and moved to GPU.
    batch_view = context_tokens.view(args.micro_batch_size, -1)
    tokens = batch_view.contiguous().cuda()

    # Get the attention mask and position ids. The middle element of the
    # returned triple is not used for generation.
    masks_and_positions = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    attention_mask, _, position_ids = masks_and_positions

    return tokens, attention_mask, position_ids

54

55
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter logits with top-k and/or nucleus (top-p) filtering.

    Mostly taken from the huggingface conversational ai code at
        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
             conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor whose last dimension indexes the vocabulary. It is
            modified in place.
        top_k: when > 0, only the ``top_k`` largest logits along the last
            dimension are kept. Values larger than the vocabulary size are
            clamped to it.
        top_p: when > 0.0, only the highest-probability tokens are kept, up
            to and including the first one whose cumulative probability
            exceeds ``top_p``.
        filter_value: value written into every filtered position.

    Returns:
        The same ``logits`` tensor, after in-place filtering.
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the size of the dimension.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens with a logit below the last of the top-k.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token
        # above the threshold.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. scatter
        # works along the last dim for any rank, so no per-row loop is
        # needed.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits


88
def generate_samples_input_from_file(model):
    """Generate a completion for each prompt line of ``args.sample_input_file``.

    Only the rank that is both on the first pipeline stage and tensor-model-
    parallel rank 0 reads the prompts and writes the results (to
    ``args.sample_output_file``, or ``<input file>.out`` when that is not
    set). All other ranks take part in the collectives with placeholder
    values so every rank executes the same sequence of communication calls.

    Generation stops at the first prompt containing "stop".
    NOTE: the final line of the input file is always replaced by "stop", so
    it is never used as a prompt.

    Args:
        model: the model passed on to ``get_token_stream``; it is switched
            to eval mode.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    fname_out = None
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        # Read every prompt up front so the input handle is released at once.
        with open(args.sample_input_file, "r") as fname:
            all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    try:
        model.eval()
        with torch.no_grad():
            while True:
                terminate_runs = 0
                raw_text_len = 0
                # Must be bound even when the very first prompt terminates
                # the run; it is packed into input_info below.
                context_length = 0

                if mpu.is_pipeline_first_stage() \
                   and mpu.get_tensor_model_parallel_rank() == 0:
                    # An empty input file has nothing to index; fall through
                    # to "stop" instead of raising IndexError.
                    if input_pos < input_count:
                        raw_text = all_raw_text[input_pos]
                    input_pos += 1
                    if input_pos >= input_count:
                        raw_text = "stop"
                    raw_text_len = len(raw_text)

                    if "stop" in raw_text:
                        terminate_runs = 1
                    else:
                        context_tokens = tokenizer.tokenize(raw_text)
                        context_length = len(context_tokens)

                        if context_length >= (args.seq_length // 2):
                            print("\nContext length", context_length,
                                  "\nPlease give smaller context (half of the "
                                  "sequence length)!", flush=True)
                            # No collective has been issued yet this
                            # iteration, so moving on to the next prompt
                            # keeps ranks in step.
                            continue
                else:
                    context_tokens = tokenizer.tokenize("EMPTY TEXT")
                    context_length = 0

                # Only the reading rank contributes non-zero values, so the
                # reduced tensor carries its values to the whole group
                # (assuming the default sum reduction).
                input_info = [terminate_runs, raw_text_len, context_length]
                input_info_tensor = torch.cuda.LongTensor(input_info)
                torch.distributed.all_reduce(input_info_tensor,
                                             group=mpu.get_model_parallel_group())
                terminate_runs = input_info_tensor[0].item()
                raw_text_len = input_info_tensor[1].item()
                context_length = input_info_tensor[2].item()

                if terminate_runs == 1:
                    return

                # For pipeline parallel we send context tokens to other stages
                # so they get the lengths correct
                if mpu.get_tensor_model_parallel_rank() == 0 \
                   and args.pipeline_model_parallel_size > 1:
                    if mpu.is_pipeline_first_stage():
                        src = mpu.get_pipeline_model_parallel_first_rank()
                        group = mpu.get_pipeline_model_parallel_group()
                        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                        torch.distributed.broadcast(context_tokens_tensor, src, group)
                    else:
                        src = mpu.get_pipeline_model_parallel_first_rank()
                        group = mpu.get_pipeline_model_parallel_group()
                        context_tokens_tensor = torch.empty(context_length,
                                                            dtype=torch.int64,
                                                            device=torch.device("cuda"))
                        torch.distributed.broadcast(context_tokens_tensor, src, group)
                        context_tokens = context_tokens_tensor.cpu().numpy().tolist()

                # Drain the stream; only the final item is used.
                token_stream = get_token_stream(model, [context_tokens])
                for _, decode_tokens in enumerate(token_stream):
                    pass

                if mpu.get_tensor_model_parallel_rank() == 0:
                    if mpu.is_pipeline_first_stage():
                        os.system('clear')
                        print("\nContext:", raw_text, flush=True)

                        fname_out.write("\nContext:")
                        fname_out.write(raw_text)

                        decode_tokens, _ = decode_tokens
                        decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                        # Drop the leading characters that correspond to the
                        # prompt itself.
                        trim_decode_tokens = tokenizer.detokenize(
                            decode_tokens)[raw_text_len:]
                        print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                        fname_out.write("\n\nMegatron-LM:")
                        fname_out.write(trim_decode_tokens)
                        fname_out.write("\n")

                raw_text = None
    finally:
        # Close (and thereby flush) the output file on every exit path.
        if fname_out is not None:
            fname_out.close()
Mohammad's avatar
Mohammad committed
192

193

194
def generate_samples_interactive(model, print_frequency=24):
    """Interactively read prompts from stdin and print generated text.

    Only the rank that is both on the first pipeline stage and tensor-model-
    parallel rank 0 talks to the terminal; all other ranks take part in the
    collectives with placeholder values. A prompt containing "stop" ends the
    session on every rank.

    Args:
        model: the model passed on to ``get_token_stream``; it is switched
            to eval mode.
        print_frequency: the partial generation is re-printed every
            ``print_frequency`` items of the token stream.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0
            # Must be bound even when the very first prompt is "stop";
            # it is packed into input_info below.
            context_length = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        # No collective has been issued yet this iteration,
                        # so asking again keeps ranks in step.
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            # Only the prompting rank contributes non-zero values, so the
            # reduced tensor carries its values to the whole group
            # (assuming the default sum reduction).
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                # Every rank must consume the stream; only the terminal rank
                # prints, and only every `print_frequency` items.
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                # If the last stream item was already printed above,
                # decode_tokens is a plain list; otherwise it is still the
                # raw (tokens, lengths) pair and must be unpacked first.
                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
Mohammad's avatar
Mohammad committed
292

293

294
295

def generate_samples_unconditional(model):
    """Yield unconditionally generated samples until ``args.num_samples``.

    Each batch starts from ``args.micro_batch_size`` contexts that contain
    only the end-of-document token. On the rank that is both on the last
    pipeline stage and tensor-model-parallel rank 0 every yielded item is a
    dict with keys ``'text'``, ``'length'`` and ``'finished'``; every other
    rank yields ``None`` the same number of times.

    Args:
        model: the model passed on to ``get_token_stream``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    # The timer runs across batches and is reset only when a value is
    # logged, so the printed figure is a true per-batch average. It also
    # includes whatever time the consumer spends between yields.
    batches_since_log = 0
    start_time = time.time()
    while True:
        # pad_batch mutates its input, hence the deep copy. Drain the
        # stream; only the final item is used.
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        batches_since_log += 1
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / batches_since_log)
                start_time = time.time()
                batches_since_log = 0
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                # Drop the leading context token and the token at the
                # recorded end position.
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            # Yield placeholders so ctr advances identically on all ranks.
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break

337

Mohammad's avatar
Mohammad committed
338
def generate_and_write_samples_unconditional(model):
    """Write unconditionally generated samples to ``args.genfile``.

    Samples are written as one JSON object per line. Only the rank that is
    both on the last pipeline stage and tensor-model-parallel rank 0 opens
    and writes the file; every other rank just drains the generator so it
    keeps taking part in the generation. Opening the file in ``'w'`` mode on
    non-writing ranks could truncate the writer's output when the path is on
    a shared filesystem.

    Args:
        model: the model passed on to ``generate_samples_unconditional``.
    """
    args = get_args()
    assert args.genfile is not None

    is_writer = mpu.is_pipeline_last_stage() and \
        mpu.get_tensor_model_parallel_rank() == 0
    if not is_writer:
        for _ in generate_samples_unconditional(model):
            pass
        return

    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            f.write(json.dumps(datum) + '\n')
347

348

Mohammad's avatar
Mohammad committed
349
350
def pad_batch(batch, pad_id, args):
    """Pad every token list in `batch` up to ``args.seq_length``, in place.

    Lists that are already ``args.seq_length`` long or longer are left
    untouched (they are not truncated).

    Args:
        batch: list of token-id lists; the inner lists are extended in place.
        pad_id: token id appended as padding.
        args: object exposing a ``seq_length`` attribute.

    Returns:
        ``(batch, context_lengths)`` where ``context_lengths[i]`` is the
        length of ``batch[i]`` before padding.
    """
    context_lengths = []
    for sample in batch:
        unpadded_length = len(sample)
        context_lengths.append(unpadded_length)
        if unpadded_length < args.seq_length:
            missing = args.seq_length - unpadded_length
            sample.extend([pad_id] * missing)
    return batch, context_lengths

359
360

def get_token_stream(model, context_tokens):
    """Yield the growing generated sequences for a batch of contexts.

    The contexts are padded (in place) to ``args.seq_length`` with the
    end-of-document token, broadcast within the tensor-model-parallel group,
    and handed to ``sample_sequence_batch``.

    Args:
        model: the model used for generation.
        context_tokens: list of token-id lists, one per sample.

    Yields:
        ``(tokens[:, :context_length], lengths)`` for each item produced by
        ``sample_sequence_batch`` whose tokens are not ``None``, where
        ``context_length`` grows by one per item; ``(None, None)`` otherwise.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    padded_tokens, unpadded_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)
    context_tokens_tensor = torch.cuda.LongTensor(padded_tokens)
    context_length_tensor = torch.cuda.LongTensor(unpadded_lengths)

    # Lengths first, then tokens: the order has to match on every rank.
    src_rank = mpu.get_tensor_model_parallel_src_rank()
    tp_group = mpu.get_tensor_model_parallel_group()
    for tensor in (context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(tensor, src_rank, group=tp_group)

    # Generation starts right after the shortest context in the batch.
    context_length = context_length_tensor.min().item()
    _, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is None:
            yield None, None
        else:
            yield tokens[:, :context_length], lengths
390
391
392


def switch(val1, val2, boolean):
    """Select element-wise between `val1` and `val2`.

    `boolean` is cast to the type of `val1`; positions where it is 1 take
    the value from `val2`, positions where it is 0 keep `val1`.
    """
    mask = boolean.type_as(val1)
    kept = (1 - mask) * val1
    replaced = mask * val2
    return kept + replaced
396

397

398
399
400
401
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run one forward pass of this rank's model for generation.

    The input is received with ``recv_forward`` and set on the unwrapped
    model, the model is called, and the output is passed on with
    ``send_forward``.

    Args:
        model: the (possibly wrapped) model to call.
        tokens, position_ids, attention_mask, tokentype_ids: model inputs.
        layer_past: cached state from a previous call, forwarded to the
            model.
        get_key_value: when truthy, the model output is unpacked as
            ``(output, layer_past)`` and both are returned.
        forward_method_parallel_output: forwarded to the model unchanged.

    Returns:
        ``output_tensor``, or ``(output_tensor, layer_past)`` when
        `get_key_value` is truthy.
    """
    # The length of `tokens` can differ from args.seq_length when not using
    # recompute, so args.seq_length is overridden for the duration of the
    # call -- per the original note, so the p2p communication functions see
    # the correct size.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(tokens, position_ids, attention_mask,
                              tokentype_ids=tokentype_ids,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)

        if get_key_value:
            output_tensor, layer_past = output_tensor

        send_forward(output_tensor)
    finally:
        # Always restore the global setting, even if the pass raised.
        args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


431
432
433
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Generate tokens one position at a time for a batch of contexts.

    Generation starts at the shortest context length in the batch and runs
    until position `maxlen` or until every sample has produced the
    end-of-document token. `context_tokens` is written to in place.

    Args:
        model: the model passed to ``forward_step``; switched to eval mode.
        context_tokens: token tensor indexed as ``[sample, position]``.
        context_lengths: per-sample context length tensor.
        attention_mask, position_ids: inputs forwarded to ``forward_step``.
        maxlen: last position to fill. Defaults to the smaller of
            ``args.seq_length - 1`` and the shortest context length plus
            ``args.out_seq_length``.
        type_ids: optional token type ids, sliced like the tokens.

    Yields:
        Once per generated position: ``(tokens, lengths)`` on the last
        pipeline stage, ``(tokens, None)`` on a first stage that is not the
        last stage, and ``(None, None)`` on any other stage.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod
        org_context_length = context_length

        step = 0
        layer_past = None
        tokens = context_tokens
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()

        if maxlen is None:
            maxlen = min(args.seq_length - 1,
                         org_context_length + args.out_seq_length)

        # Samples that never emit the end token keep this default length.
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= maxlen:
            if args.recompute:
                # Re-run the full sequence and read the logits at the
                # position just before the one being generated.
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                # The first step feeds the whole prefix; later steps feed
                # only the newest position together with layer_past.
                if step == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    types2use = None if type_ids is None \
                        else type_ids[:, :context_length]
                else:
                    newest = context_length - 1
                    tokens2use = tokens[:, newest].view(batch_size, -1)
                    positions2use = position_ids[:, newest].view(
                        batch_size, -1)
                    types2use = None if type_ids is None \
                        else type_ids[:, newest].view(batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(probs, num_samples=1).view(-1)

                # Samples whose context extends past this position keep
                # their own token; the rest take the sampled one.
                started = context_lengths <= context_length
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                torch.distributed.broadcast(
                    new_tokens,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_embedding_group())

                # Record the position at which each sample first finishes.
                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                torch.distributed.broadcast(
                    done,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_pipeline_model_parallel_group())
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    # Receive the token chosen on the last stage.
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(
                        new_tokens,
                        mpu.get_pipeline_model_parallel_last_rank(),
                        mpu.get_embedding_group())
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the termination flag from the last stage; on
                # these stages this happens after the yield.
                done = torch.cuda.ByteTensor([0])
                torch.distributed.broadcast(
                    done,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_pipeline_model_parallel_group())

            context_length += 1
            step += 1
            if done:
                break