# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.p2p_communication import recv_forward, send_forward

def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p)
    filtering. This function has been mostly taken from the huggingface
    conversational ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits to compute cumulative probabilities for top-p filtering.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
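
# A minimal usage sketch (not part of the original file): shows how
# top_k_logits() is typically applied before sampling. The tensor values and
# the helper name _example_top_k_top_p_sampling are made up for illustration.
def _example_top_k_top_p_sampling():
    """Sample one token id per row after top-k / nucleus (top-p) filtering."""
    logits = torch.randn(2, 8)           # [batch, vocab] toy logits
    # clone() because top_k_logits modifies its input in place.
    filtered = top_k_logits(logits.clone(), top_k=3, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)  # filtered-out entries get zero probability
    return torch.multinomial(probs, num_samples=1)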


def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1


def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1


def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')
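
# Each line written to args.genfile by the function above is one JSON record of
# the (illustrative) form:
#     {"text": "<detokenized sample>", "length": 42, "finished": true}
# where 'length' is the sample length and 'finished' indicates that generation
# hit the end-of-document token before reaching args.seq_length.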


def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
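
# A minimal sketch (not part of the original file) of pad_batch() behaviour:
# each context is padded in place to args.seq_length with pad_id, and the
# original (unpadded) lengths are returned. SimpleNamespace stands in for
# Megatron's parsed arguments; the token ids below are made up.
def _example_pad_batch():
    from types import SimpleNamespace
    fake_args = SimpleNamespace(seq_length=8)
    batch, lengths = pad_batch([[5, 6, 7], [1, 2]], pad_id=0, args=fake_args)
    # batch   -> [[5, 6, 7, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0]]
    # lengths -> [3, 2]
    return batch, lengths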


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
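
# A minimal sketch (not part of the original file): switch() picks val2 where
# the mask is 1 and keeps val1 elsewhere. During sampling this is how newly
# sampled tokens overwrite a position only for rows that have moved past
# their prompt.
def _example_switch():
    val1 = torch.tensor([10, 20, 30])
    val2 = torch.tensor([11, 21, 31])
    mask = torch.tensor([0, 1, 1])
    return switch(val1, val2, mask)  # tensor([10, 21, 31])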


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):

    # The sequence length seen by the model changes when not using recompute,
    # so tell the p2p_communication functions the correct size via args.seq_length.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

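    # Receive activations from the previous pipeline stage (None on the first stage).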
    input_tensor = recv_forward()

    # Forward pass through the model.
    model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):
            if args.recompute:
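                # Recompute mode: re-run the full forward pass over all tokens
                # generated so far at every step (no key/value caching).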
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
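                # Incremental mode: after the first step, feed only the latest
                # token and reuse the cached key/value states (layer_past).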
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
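                # Only the last pipeline stage has logits: sample the next token
                # here and broadcast it (and the done flag) to the other stages.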
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
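                    # Despite the name, these are probabilities (softmax), which
                    # is what torch.multinomial expects.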
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
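                # Stages other than the last one receive the sampled tokens
                # (first stage only) and the termination flag from the last stage.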
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break