# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model; it would be nice to move them into
# megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p)
    filtering. This function is mostly taken from the Hugging Face
    conversational AI tutorial:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the logits in descending order and compute the cumulative
        # probabilities.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits

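# A minimal sketch of how top_k_logits is used during sampling elsewhere in
# this file (the specific values top_k=40, top_p=0.9 are illustrative, not
# defaults from this file):
#
#   logits = logits.float() / temperature
#   logits = top_k_logits(logits, top_k=40, top_p=0.9)
#   probs = F.softmax(logits, dim=-1)
#   prev = torch.multinomial(probs, num_samples=1).view(-1)
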
def generate_samples_input_from_file(model):
    """Generate samples for the prompts in `--sample-input-file` and write the
    outputs to `--sample-output-file`."""

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give a smaller context (at most half "
                              "of the sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallelism, send the context tokens to the other
            # stages so they get the lengths correct.
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1

def generate_samples_line_by_line_input_from_file(model):
    """Generate one sample per line of `--sample-input-file` and write the
    generations, one per line, to `--sample-output-file`."""

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]

                    # Strip carriage returns and newlines so each generation
                    # occupies a single output line.
                    trim_decode_tokens = trim_decode_tokens.replace("\r", "")
                    trim_decode_tokens = trim_decode_tokens.replace("\n", "")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1

            if input_pos == input_count:
                return


def generate_samples_prompt_input_from_file(model):
    """Generate samples for the inputs in `--sample-input-file`, prepending a
    few-shot prompt built from `--prompt-file`, and write the generations to
    `--sample-output-file`."""

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # Read the prompt file and build the few-shot prompt.
    with open(args.prompt_file, "r") as f:
        prompt_examples = f.readlines()

    prompt_examples = prompt_examples[:args.num_prompt_examples]
    prompt = ""
    for instance in prompt_examples:
        instance = instance.strip()
        prompt += instance + " \n"

    assert args.prompt_type in ["context", "keyphrase"]
    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                control_codes = splits[0].split(" [CTRL] ")
                topic = control_codes[0]

                raw_text = prompt
                if args.prompt_type == "context":
                    # Use the last dialog turn as the context.
                    turns = splits[1].split(" [SEP] ")
                    context = turns[-1]
                    raw_text += "( " + context + " ) " + topic + " :"
                else:
                    # Use the keyphrases as the control signal.
                    keyphrase_list = control_codes[1:]

                    for i, keyphrase in enumerate(keyphrase_list):
                        if i == 0:
                            raw_text += "( "
                        else:
                            raw_text += "; "
                        raw_text += keyphrase

                    if len(keyphrase_list) > 0:
                        raw_text += " ) "
                    raw_text += topic + " :"

                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]

                    # Keep only the first generated line.
                    generated_output = trim_decode_tokens.split("\n")[0]
                    generated_output = generated_output.strip()

                    fname_out.write(generated_output)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1

            if input_pos == input_count:
                return


def dialog_with_gpt_control_interactive(conv_model, ctrl_model, add_separtor):
    """Interactive dialog loop in which a control GPT model first generates a
    control sentence and the conversational model then generates the response
    conditioned on it."""
    args = get_args()
    tokenizer = get_tokenizer()

    conv_model.eval()
    ctrl_model.eval()
    dialog_history = []
    with torch.no_grad():
        while True:
            ctrl_input_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                # Use " @@ " to separate the control code from the current turn.
                input_text = input(">>> ")
                while not input_text:
                    print("Input should not be empty!")
                    input_text = input(">>> ")

                assert " @@ " in input_text, "Please input with a correct template"
                splits = input_text.split(" @@ ")
                ctrl_code = splits[0]
                curr_turn = splits[1]
                prev_two_turns = ""
                if add_separtor:
                    for i, turn in enumerate(dialog_history[-2:]):
                        if i == 0:
                            prev_two_turns = "<< " + turn + " >>"
                        else:
                            prev_two_turns += " "
                            prev_two_turns += "<< " + turn + " >>"
                else:
                    prev_two_turns = " ".join(dialog_history[-2:])
                dialog_history.append(curr_turn)

                print("\nHistory:", prev_two_turns)
                print("User:", curr_turn)

                if add_separtor:
                    curr_turn = "<< " + curr_turn + " >>"

                if prev_two_turns != "":
                    dialog_context = prev_two_turns + " " + curr_turn
                else:
                    dialog_context = curr_turn
                ctrl_input = ctrl_code + " " + dialog_context

                if add_separtor:
                    ctrl_input += " :"

                ctrl_input_text_len = len(ctrl_input)
                ctrl_context_tokens = tokenizer.tokenize(ctrl_input)
            else:
                ctrl_context_tokens = tokenizer.tokenize("EMPTY TEXT")

            # Generate the control sentence with the control model.
            token_stream = get_token_stream(ctrl_model, [ctrl_context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    control_sent = tokenizer.detokenize(
                        decode_tokens)[ctrl_input_text_len:]

            control_sent = control_sent.replace("<|endoftext|>", "")
            print("\nControl Sentence:", control_sent)

            # Condition the conversational model on the control sentence.
            if control_sent != "":
                control_sent = "( " + control_sent + " )"
                conv_input = control_sent + " " + dialog_context
            else:
                conv_input = dialog_context

            conv_input_text_len = len(conv_input)

            conv_context_tokens = tokenizer.tokenize(conv_input)
            token_stream = get_token_stream(conv_model, [conv_context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    response = tokenizer.detokenize(
                        decode_tokens)[conv_input_text_len:]

            response = response.replace("<|endoftext|>", "")
            print("\nChatbot:", response)
            dialog_history.append(response)


def dialog_with_dpr_control_interactive(conv_model, ctrl_model, ctrl_tokenizer,
                        knowledge_corpus, knowledge_corpus_emb, add_separtor):
    """Interactive dialog loop in which a DPR-style retriever (ctrl_model)
    selects a control sentence from the knowledge corpus and the
    conversational model generates the response conditioned on it."""
    args = get_args()
    tokenizer = get_tokenizer()

    conv_model.eval()
    ctrl_model.eval()
    dialog_history = []
    with torch.no_grad():
        while True:
            input_text = input(">>> ")
            while not input_text:
                print("Input should not be empty!")
                input_text = input(">>> ")

            assert " @@ " in input_text, "Please input with a correct template"
            splits = input_text.split(" @@ ")
            ctrl_code = splits[0]
            curr_turn = splits[1]
            prev_two_turns = " ".join(dialog_history[-2:])

            prev_two_turns_v2 = ""
            if add_separtor:
                for i, turn in enumerate(dialog_history[-2:]):
                    if i == 0:
                        prev_two_turns_v2 = "<< " + turn + " >>"
                    else:
                        prev_two_turns_v2 += " "
                        prev_two_turns_v2 += "<< " + turn + " >>"
            else:
                prev_two_turns_v2 = prev_two_turns
            dialog_history.append(curr_turn)

            print("\nHistory:", prev_two_turns_v2)
            print("\nUser:", curr_turn)

            if prev_two_turns != "":
                dialog_context = prev_two_turns + " " + curr_turn
            else:
                dialog_context = curr_turn

            if add_separtor:
                curr_turn = "<< " + curr_turn + " >>"
                dialog_context_v2 = prev_two_turns_v2 + curr_turn
            else:
                dialog_context_v2 = dialog_context

            # Retrieve the most relevant knowledge sentence with the DPR model.
            ctrl_input = ctrl_code + " " + dialog_context

            ctrl_input_ids = ctrl_tokenizer.encode(ctrl_input)
            ctrl_input_ids = torch.LongTensor([ctrl_input_ids]).cuda()
            attn_masks = torch.ones(1, ctrl_input_ids.size()[-1]).cuda()

            query_emb = ctrl_model(input_ids=ctrl_input_ids,
                                   attention_mask=attn_masks).pooler_output  # (1, 768)

            logits = knowledge_corpus_emb.matmul(query_emb[0])
            retrieved_idx = torch.argmax(logits).item()
            control_sent = knowledge_corpus[retrieved_idx].strip()

            print("\nControl Sentence:", control_sent)

            # Condition the conversational model on the retrieved sentence.
            if control_sent != "":
                control_sent = "( " + control_sent + " )"
                conv_input = control_sent + " " + dialog_context_v2
            else:
                conv_input = dialog_context_v2

            conv_input_text_len = len(conv_input)

            conv_context_tokens = tokenizer.tokenize(conv_input)
            token_stream = get_token_stream(conv_model, [conv_context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    response = tokenizer.detokenize(
                        decode_tokens)[conv_input_text_len:]

            response = response.replace("<|endoftext|>", "")
            print("\nChatbot:", response)
            dialog_history.append(response)

# We added this function to support evaluation tasks such as SQuAD and DROP
# in the https://github.com/EleutherAI/lm-evaluation-harness codebase. The
# lm-evaluation-harness code can now call this function in the same way as its
# existing generate call for GPT-style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate a completion for `context` for LM evaluation (e.g. the
    EleutherAI lm-evaluation-harness) and return the generated text."""
    # TODO: need to think about the eos token.

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens

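# A minimal sketch of how an evaluation harness might call the function above
# (the context string and argument values are hypothetical):
#
#   completion = generate_samples_eval(model,
#                                      "Question: ... Answer:",
#                                      max_gen_length=32,
#                                      eos_token_id=tokenizer.eod)
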
def generate_samples_interactive(model, print_frequency=24):
    """Interactively prompt the user for a context and print generated
    continuations as they are produced."""

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    # context_tokens = context_tokens + [tokenizer.sep_id]
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give a smaller context (at most half "
                              "of the sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallelism, send the context tokens to the other
            # stages so they get the lengths correct.
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                # print("tokenzied inputs:", tokenizer.tokenize(raw_text))
                # print("decode_tokens:", decode_tokens)
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                # trim_decode_tokens = tokenizer.detokenize(
                #     decode_tokens[context_length:])
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                # print("decode_tokens:", decode_tokens)
                # trim_decode_tokens = tokenizer.detokenize(
                #     decode_tokens[context_length:])
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1

def generate_samples_unconditional(model):
    """Generate unconditional samples (the context is just the EOD token) and
    yield a dict with the text, length, and finished flag for each sample
    (None on ranks that do not hold the output)."""

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break

def generate_and_write_samples_unconditional(model):
    """Generate unconditional samples and write them as JSON lines to
    `--genfile`."""

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')

def pad_batch(batch, pad_id, args):
    """Pad each context in the batch with `pad_id` up to `args.seq_length` and
    return the padded batch along with the original context lengths."""
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths

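# Example of the padding behavior (hypothetical values): with
# args.seq_length == 6, pad_batch([[5, 1, 7]], pad_id=0, args) returns
# ([[5, 1, 7, 0, 0, 0]], [3]).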

def get_token_stream(model, context_tokens):
    """Broadcast the (padded) context tokens across the tensor-parallel group
    and yield (tokens, lengths) as the sampler extends the sequences."""

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):
    """Element-wise select: return `val2` where `boolean` is true and `val1`
    elsewhere."""
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

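# Example of the switch behavior (hypothetical values):
#   switch(torch.tensor([1., 2., 3.]), torch.tensor([7., 8., 9.]),
#          torch.tensor([0, 1, 0])) -> tensor([1., 8., 3.])
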
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run a forward pass for generation, handling pipeline-parallel
    communication."""

    # The hidden size changes when not using recompute, so we need to tell the
    # p2p_communication functions the correct size.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Autoregressively sample a batch of sequences, yielding the (partially)
    generated tokens and, on the last pipeline stage, the sequence lengths."""

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # eos_id was added to support generate_samples_eval, which passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):
            if args.recompute:
                # Recompute the forward pass over the full sequence each step.
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                # Incremental decoding: feed only the new token and reuse the
                # cached key/value states.
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                # Only overwrite positions that are past each sample's context.
                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1

            if done:
                break