# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for generating text."""

import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model; it would be nice to put them in megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()
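
    # Note: context_tokens is expected to contain micro_batch_size sequences
    # of equal length; this helper only reshapes them, moves them to GPU and
    # builds the left-to-right (causal) mask and position ids.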

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits in descending order and compute cumulative probabilities
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
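
# Illustrative example of the filtering above (a minimal sketch; the numbers
# are made up). With top_k=2, every logit below the second-largest value is
# pushed to -inf, so sampling can only pick one of the two most likely tokens:
#
#     >>> logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#     >>> top_k_logits(logits, top_k=2)
#     tensor([[2., 1., -inf, -inf]])
#
# Note that the filtering modifies the provided logits tensor in place.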


def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                if input_pos == input_count:
                    raw_text = "stop"
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give a smaller context (at most half "
                              "of the sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1


def generate_samples_line_by_line_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                    'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]

                    if "\r" in trim_decode_tokens:
                        trim_decode_tokens = trim_decode_tokens.replace("\r", "")
                    if "\n" in trim_decode_tokens:
                        trim_decode_tokens = trim_decode_tokens.replace("\n", "")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1

            if input_pos == input_count:
                return


def generate_samples_prompt_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()
    from nltk import word_tokenize

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                    'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # Read the prompt file
    if args.dynamic_prompt:
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]

                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"

                    prompt_examples_dict[key] = prompt

    else:
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

    assert args.prompt_type in ["knowledge", "knowledge_notopic", "dialogue", "dialogue_notopic"]
    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                control_codes = splits[0].split(" [CTRL] ")
                topic = control_codes[0]

                if args.dynamic_prompt:
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                else:
                    raw_text = prompt

                if args.prompt_type == "knowledge":
                    turns = splits[1].split(" [SEP] ")
                    context = turns[-1]
                    raw_text += "( " + context + " ) " + topic + " =>"
                    # raw_text += "( " + context + " ) " + topic + ":"
                    # raw_text += "( " + context + " ) " + topic + " ->"
                
                elif args.prompt_type == "knowledge_notopic":
                    turns = splits[1].split(" [SEP] ")[-3:]
                    for j, turn in enumerate(turns):
                        if j != 0:
                            raw_text += " "
                        else:
                            raw_text += "( " + turn + " )"
                    raw_text += " =>"
                
                elif args.prompt_type == "dialogue":
                    turns = splits[1].split(" [SEP] ")
                    # context = turns[-1]
                    ctrl_sent = splits[2]
                    ctrl_sent = " ".join(word_tokenize(ctrl_sent))

                    # ## version one
                    # turns = turns[-3:]
                    # raw_text += "Topic: " + topic + ". "
                    # if len(turns) == 2:
                    #     for idx, turn in enumerate(turns):
                    #         if idx % 2 == 0:
                    #             raw_text += "System: " + turn + " "
                    #         else:
                    #             raw_text += "User: " + turn + " "
                    # else:
                    #     for idx, turn in enumerate(turns):
                    #         if idx % 2 == 0:
                    #             raw_text += "User: " + turn + " "
                    #         else:
                    #             raw_text += "System: " + turn + " "
                    # raw_text += "We know that: " + ctrl_sent + " "
                    # raw_text += "Therefore, the System will say:"

                    ## version two
                    last_turn = turns[-1]
                    ctrl_sent = ctrl_sent.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + ctrl_sent + " "
                    raw_text += "System replies:"

                else:
                    turns = splits[1].split(" [SEP] ")
                    # context = turns[-1]
                    ctrl_sent = splits[2]
                    ctrl_sent = " ".join(word_tokenize(ctrl_sent))

                    ## version two
                    last_turn = turns[-1]
                    ctrl_sent = ctrl_sent.strip()
                    last_turn = last_turn.strip()
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + ctrl_sent + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass
            
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    
                    generated_output = trim_decode_tokens.split("\n")[0]
                    generated_output = generated_output.strip()

                    fname_out.write(generated_output)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1

            if input_pos == input_count:
                return


def dialog_with_gpt_control_interactive(conv_model, ctrl_model, add_separtor):
    args = get_args()
    tokenizer = get_tokenizer()

    conv_model.eval()
    ctrl_model.eval()
    dialog_history = []
    with torch.no_grad():
        while True:
            ctrl_input_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                # the input uses " @@ " to separate the control code from the
                # current turn
                input_text = input(">>> ")
                while not input_text:
                    print("Input should not be empty!")
                    input_text = input(">>> ")
                
                assert " @@ " in input_text, "Please input with a correct template"
                splits = input_text.split(" @@ ")
                ctrl_code = splits[0]
                curr_turn = splits[1]
                prev_two_turns = ""
                if add_separtor:
                    for i, turn in enumerate(dialog_history[-2:]):
                        if i == 0:
                            prev_two_turns = "<< " + turn + " >>"
                        else:
                            prev_two_turns += " "
                            prev_two_turns += "<< " + turn + " >>"
                else:
                    prev_two_turns = " ".join(dialog_history[-2:])
                dialog_history.append(curr_turn)

                print("\nHistory:", prev_two_turns)
                print("User:", curr_turn)

                if add_separtor:
                    curr_turn = "<< " + curr_turn + " >>"

                if prev_two_turns != "":
                    dialog_context = prev_two_turns + " " + curr_turn
                else:
                    dialog_context = curr_turn
                ctrl_input = ctrl_code + " " + dialog_context
                
                if add_separtor:
                    ctrl_input += " :"

                ctrl_input_text_len = len(ctrl_input)
                ctrl_context_tokens = tokenizer.tokenize(ctrl_input)

            else:
                ctrl_context_tokens = tokenizer.tokenize("EMPTY TEXT")
            
            token_stream = get_token_stream(ctrl_model, [ctrl_context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    control_sent = tokenizer.detokenize(
                        decode_tokens)[ctrl_input_text_len:]
            
            control_sent = control_sent.replace("<|endoftext|>", "")
            print("\nControl Sentence:", control_sent)
            
            if control_sent != "":
                control_sent = "( " + control_sent + " )"
                conv_input = control_sent + " " + dialog_context
            else:
                conv_input = dialog_context
            
            conv_input_text_len = len(conv_input)
            
            conv_context_tokens = tokenizer.tokenize(conv_input)
            token_stream = get_token_stream(conv_model, [conv_context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    response = tokenizer.detokenize(
                        decode_tokens)[conv_input_text_len:]

            response = response.replace("<|endoftext|>", "")
            print("\nChatbot:", response)
            dialog_history.append(response)


def dialog_with_dpr_control_interactive(conv_model, ctrl_model, ctrl_tokenizer,
                        knowledge_corpus, knowledge_corpus_emb, add_separtor):
    args = get_args()
    tokenizer = get_tokenizer()
    
    conv_model.eval()
    ctrl_model.eval()
    dialog_history = []
    with torch.no_grad():
        while True:
            input_text = input(">>> ")
            while not input_text:
                print("Input should not be empty!")
                input_text = input(">>> ")

            assert " @@ " in input_text, "Please input with a correct template"
            splits = input_text.split(" @@ ")
            ctrl_code = splits[0]
            curr_turn = splits[1]
            prev_two_turns = " ".join(dialog_history[-2:])

            prev_two_turns_v2 = ""
            if add_separtor:
                for i, turn in enumerate(dialog_history[-2:]):
                    if i == 0:
                        prev_two_turns_v2 = "<< " + turn + " >>"
                    else:
                        prev_two_turns_v2 += " "
                        prev_two_turns_v2 += "<< " + turn + " >>"
            else:
                prev_two_turns_v2 = prev_two_turns
            dialog_history.append(curr_turn)

            print("\nHistory:", prev_two_turns_v2)
            print("\nUser:", curr_turn)

            if prev_two_turns != "":
                dialog_context = prev_two_turns + " " + curr_turn
            else:
                dialog_context = curr_turn

            if add_separtor:
                curr_turn = "<< " + curr_turn + " >>"
                dialog_context_v2 = prev_two_turns_v2 + curr_turn
            else:
                dialog_context_v2 = dialog_context

            ctrl_input = ctrl_code + " " + dialog_context

            ctrl_input_ids = ctrl_tokenizer.encode(ctrl_input)
            ctrl_input_ids = torch.LongTensor([ctrl_input_ids]).cuda()
            attn_masks = torch.ones(1, ctrl_input_ids.size()[-1]).cuda()

            query_emb = ctrl_model(input_ids=ctrl_input_ids,
                                   attention_mask=attn_masks).pooler_output # (1,768)

            logits = knowledge_corpus_emb.matmul(query_emb[0])
            retrieved_idx = torch.argmax(logits).item()
            control_sent = knowledge_corpus[retrieved_idx].strip()
            
            print("\nControl Sentence:", control_sent)

            if control_sent != "":
                control_sent = "( " + control_sent + " )"
                conv_input = control_sent + " " + dialog_context_v2
            else:
                conv_input = dialog_context_v2

            conv_input_text_len = len(conv_input)
            
            conv_context_tokens = tokenizer.tokenize(conv_input)
            token_stream = get_token_stream(conv_model, [conv_context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    response = tokenizer.detokenize(
                        decode_tokens)[conv_input_text_len:]

            response = response.replace("<|endoftext|>", "")
            print("\nChatbot:", response)
            dialog_history.append(response)



# We added this function to support evaluation of tasks such as SQuAD
# and DROP from the https://github.com/EleutherAI/lm-evaluation-harness
# codebase. The lm-evaluation-harness code can now call this function
# similarly to its existing generate call used for GPT-style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    # Generate samples for lm evaluation
    # NEED TO THINK ABOUT eos token

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            if counter == args.out_seq_length:
                break

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]
 
    return trim_decode_tokens
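
# Sketch of how an external harness might call generate_samples_eval; the
# model is assumed to be built and loaded elsewhere, and 198 is only a
# placeholder stop id (the caller passes whatever eos id it needs):
#
#     >>> context = "Question: Who wrote Hamlet?\nAnswer:"
#     >>> text = generate_samples_eval(model, context,
#     ...                              max_gen_length=32, eos_token_id=198)
#     >>> print(text)  # generated continuation only, the prompt is stripped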


def generate_samples_interactive(model, print_frequency=24):

    args = get_args()
    tokenizer = get_tokenizer()

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    # context_tokens = context_tokens + [tokenizer.sep_id]
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give a smaller context (at most half "
                              "of the sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                # print("tokenzied inputs:", tokenizer.tokenize(raw_text))
                # print("decode_tokens:", decode_tokens)

                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                # trim_decode_tokens = tokenizer.detokenize(
                #     decode_tokens[context_length:])
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                # print("decode_tokens:", decode_tokens)
                # trim_decode_tokens = tokenizer.detokenize(
                #     decode_tokens[context_length:])
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1



def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths
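
# A small worked example of pad_batch (values are illustrative). With
# args.seq_length = 8 and pad_id = 0, a two-sample batch
#     [[5, 6, 7], [9, 9]]
# is padded in place to
#     [[5, 6, 7, 0, 0, 0, 0, 0], [9, 9, 0, 0, 0, 0, 0, 0]]
# and the returned context_lengths are [3, 2], i.e. the original prompt
# lengths before padding.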


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2
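
# switch works elementwise: wherever `boolean` is 1 the value from val2 is
# taken, otherwise val1 is kept. For example (made-up values):
#     val1 = [11, 12, 13], val2 = [21, 22, 23], boolean = [0, 1, 0]
#     -> [11, 22, 13]
# In sample_sequence_batch below, this keeps a sample's prompt tokens in place
# and only overwrites positions that lie beyond that sample's own context.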


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):

    # The size of the communicated tensor changes when not using recompute
    # (the sequence length varies), so the p2p_communication functions need
    # to be told the correct size.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):
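            # Two decode paths follow. When args.recompute is set, the full
            # token sequence is re-run through the model at every step and the
            # logits at position context_length - 1 are read off. Otherwise
            # only the newest token (or the whole prompt on the first
            # iteration) is fed in, with past key/values cached in layer_past,
            # i.e. standard incremental decoding.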
            if args.recompute:
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
                                                  get_key_value=True,
                                                  tokentype_ids=types2use,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break