text_generation_utils.py 23.6 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
from megatron import get_args
zihanl's avatar
zihanl committed
27
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
28
29
from megatron import get_tokenizer
from megatron import mpu
30
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
31
from megatron.p2p_communication import recv_forward, send_forward
32

33
34
35
36
37
# These are needed to unwrap the model; it would be nice to move them into
# megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

38
39
40
41
def get_batch(context_tokens):
    """Build the model inputs for a batch of context tokens.

    The context tokens are reshaped to `args.micro_batch_size` rows, made
    contiguous and moved to the GPU. The attention mask and position ids
    are then derived with `get_ltor_masks_and_position_ids`.

    Args:
        context_tokens: tensor of token ids that can be viewed as
            `(args.micro_batch_size, -1)`.

    Returns:
        A `(tokens, attention_mask, position_ids)` tuple.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # One row per micro-batch entry, resident on the GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1)
    tokens = tokens.contiguous().cuda()

    # Get the attention mask and position ids. The middle value returned
    # by the helper is not needed for generation and is discarded.
    masks_and_positions = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    attention_mask, _, position_ids = masks_and_positions

    return tokens, attention_mask, position_ids

55

56
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter logits with top-k and/or nucleus (top-p) filtering.

    This function has been mostly taken from huggingface conversational
    ai code at
        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
             conversational-ai-with-transfer-learning-2d818ac26313

    The filtering is applied along the last dimension, so `logits` may
    have any number of leading dimensions (including none).

    Args:
        logits: tensor of unnormalized scores; it is modified in place.
        top_k: if > 0, keep only the `top_k` highest scores. Values larger
            than the size of the last dimension are clamped to that size.
        top_p: if > 0.0, keep the smallest set of highest scores whose
            cumulative softmax probability exceeds `top_p`.
        filter_value: value written over every filtered-out entry.

    Returns:
        The same `logits` tensor, with filtered entries set to
        `filter_value`.
    """

    if top_k > 0:
        # torch.topk raises when k exceeds the size of the last dimension,
        # so clamp instead of failing on small vocabularies.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens with a score less than the last token of the
        # top-k.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits.masked_fill_(logits < kth_best, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token
        # above the threshold.
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order in one
        # vectorized step instead of looping over rows.
        indices_to_remove = torch.zeros_like(
            sorted_indices_to_remove).scatter_(
                -1, sorted_indices, sorted_indices_to_remove)
        logits.masked_fill_(indices_to_remove, filter_value)

    return logits


89
def generate_samples_input_from_file(model):
    """Generate a continuation for every line of `args.sample_input_file`.

    The first pipeline stage on tensor-model-parallel rank 0 reads the
    prompts, prints each prompt and its generated text, and writes both to
    `args.sample_output_file` (or `<input file>.out` when that is not set).
    All other ranks take part in the collectives and in the forward passes.

    Generation ends when the input is exhausted or when a line containing
    "stop" is read. Prompts whose token count is at least half of
    `args.seq_length` are reported and skipped.

    Args:
        model: the model to sample from; it is put into eval mode.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'

    is_io_rank = mpu.is_pipeline_first_stage() \
        and mpu.get_tensor_model_parallel_rank() == 0

    all_raw_text = []
    fname_out = None
    if is_io_rank:
        # Read everything up front so the input handle is closed right away.
        with open(args.sample_input_file, "r") as fname:
            all_raw_text = fname.readlines()
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")
    input_count = len(all_raw_text)
    input_pos = 0

    model.eval()
    try:
        with torch.no_grad():
            while True:
                terminate_runs = 0
                raw_text_len = 0
                # Keep this bound even when the very first prompt
                # terminates the run; it is packed into `input_info` below.
                context_length = 0

                if is_io_rank:
                    if input_pos >= input_count:
                        # Input exhausted (or empty): signal termination
                        # only after every line has been processed.
                        raw_text = "stop"
                    else:
                        raw_text = all_raw_text[input_pos]
                        input_pos += 1
                    raw_text_len = len(raw_text)

                    if "stop" in raw_text:
                        terminate_runs = 1
                    else:
                        context_tokens = tokenizer.tokenize(raw_text)
                        context_length = len(context_tokens)

                        if context_length >= (args.seq_length // 2):
                            print("\nContext length", context_length,
                                  "\nPlease give smaller context (half of the "
                                  "sequence length)!", flush=True)
                            continue
                else:
                    context_tokens = tokenizer.tokenize("EMPTY TEXT")
                    context_length = 0

                # Only the I/O rank contributes non-zero values, so the
                # sum over the model-parallel group hands its values to
                # every rank.
                input_info = [terminate_runs, raw_text_len, context_length]
                input_info_tensor = torch.cuda.LongTensor(input_info)
                torch.distributed.all_reduce(
                    input_info_tensor,
                    group=mpu.get_model_parallel_group())
                terminate_runs = input_info_tensor[0].item()
                raw_text_len = input_info_tensor[1].item()
                context_length = input_info_tensor[2].item()

                if terminate_runs == 1:
                    return

                # For pipeline parallel we send context tokens to other
                # stages so they get the lengths correct.
                if mpu.get_tensor_model_parallel_rank() == 0 \
                   and args.pipeline_model_parallel_size > 1:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    if mpu.is_pipeline_first_stage():
                        context_tokens_tensor = torch.cuda.LongTensor(
                            context_tokens)
                        torch.distributed.broadcast(
                            context_tokens_tensor, src, group)
                    else:
                        context_tokens_tensor = torch.empty(
                            context_length,
                            dtype=torch.int64,
                            device=torch.device("cuda"))
                        torch.distributed.broadcast(
                            context_tokens_tensor, src, group)
                        context_tokens = \
                            context_tokens_tensor.cpu().numpy().tolist()

                # Drain the stream; only the final state is reported.
                token_stream = get_token_stream(model, [context_tokens])
                for _, decode_tokens in enumerate(token_stream):
                    pass

                if is_io_rank:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    # Drop the prompt characters from the detokenized text.
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

                raw_text = None
    finally:
        # The loop exits via `return` (or an exception); always flush and
        # close the output file.
        if fname_out is not None:
            fname_out.close()
Mohammad's avatar
Mohammad committed
193

zihanl's avatar
zihanl committed
194

Mostofa Patwary's avatar
Mostofa Patwary committed
195
196
197
198
# This function supports task evaluation (e.g. SQuAD and DROP) through the
# https://github.com/EleutherAI/lm-evaluation-harness codebase.
# The lm-evaluation-harness code can call it in the same way as the
# generate function it currently uses for GPT-style models.
Mostofa Patwary's avatar
Mostofa Patwary committed
199
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate a continuation of `context` for lm evaluation.

    For the duration of the call, `args.out_seq_length` is set to
    `max_gen_length` plus the number of context tokens and `args.eos_id`
    is set to `eos_token_id`. Both are put back to their previous state
    afterwards so that later generation calls in the same process are not
    affected by this call's settings.

    Args:
        model: the model to sample from; it is put into eval mode.
        context: prompt string.
        max_gen_length: number of tokens to generate beyond the context.
        eos_token_id: token id stored in `args.eos_id` while generating.

    Returns:
        The detokenized output with the first `len(context)` characters
        removed.
    """
    # NEED TO THINK ABOUT eos token

    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)

    # Remember the current settings (or their absence) for restoration.
    missing = object()
    saved_settings = (
        ('out_seq_length', getattr(args, 'out_seq_length', missing)),
        ('eos_id', getattr(args, 'eos_id', missing)),
    )
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    try:
        with torch.no_grad():
            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                if counter == args.out_seq_length:
                    break
    finally:
        for name, saved in saved_settings:
            if saved is missing:
                delattr(args, name)
            else:
                setattr(args, name, saved)

    decode_tokens, _ = decode_tokens
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens

226

227
def generate_samples_interactive(model, print_frequency=24):
    """Read prompts from stdin and print the generated continuations.

    The first pipeline stage on tensor-model-parallel rank 0 prompts the
    user, prints intermediate output every `print_frequency` generation
    steps and prints the final output. All other ranks take part in the
    collectives and in the forward passes.

    The loop ends when a prompt containing "stop" is entered. Prompts whose
    token count is at least half of `args.seq_length` are rejected and the
    user is asked again.

    Args:
        model: the model to sample from; it is put into eval mode.
        print_frequency: number of generation steps between two
            intermediate printouts.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0
            # Keep this bound even when the very first prompt is "stop";
            # it is packed into `input_info` below.
            context_length = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")
                context_length = 0

            # Only the prompting rank contributes non-zero values, so the
            # sum over the model-parallel group hands its values to every
            # rank.
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                # Every rank must consume the stream; only the prompting
                # rank prints, and only every `print_frequency` steps.
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
                   or not mpu.is_pipeline_first_stage():
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                # Drop the prompt characters from the detokenized text.
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                # The last step was already converted to a list if it
                # happened to fall on a printing step.
                if not isinstance(decode_tokens, list):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(
                    decode_tokens)[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")

            raw_text = None
Mohammad's avatar
Mohammad committed
334

335

336
337

def generate_samples_unconditional(model):
    """Yield unconditionally generated samples until `args.num_samples`.

    Every batch starts from `args.micro_batch_size` contexts that contain
    only the tokenizer's end-of-document token.

    On the last pipeline stage with tensor-model-parallel rank 0, each
    yielded item is a dict with the keys 'text', 'length' and 'finished'.
    On every other rank `None` is yielded once per sample so that all ranks
    advance in lockstep.

    Args:
        model: the model to sample from.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    # Generation time and batch count accumulated since the last log line,
    # so that the reported value is a true per-batch average.
    generation_time = 0.0
    batches_since_log = 0
    while True:
        batch_start = time.time()
        # pad_batch extends the context lists in place, hence the copy.
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        generation_time += time.time() - batch_start
        batches_since_log += 1

        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:', generation_time / batches_since_log)
                generation_time = 0.0
                batches_since_log = 0
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                # Strip the leading context token and the final token.
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break

379

Mohammad's avatar
Mohammad committed
380
def generate_and_write_samples_unconditional(model):
    """Write unconditionally generated samples to `args.genfile`.

    Every rank opens the file and iterates over
    `generate_samples_unconditional`, but only the last pipeline stage
    with tensor-model-parallel rank 0 writes: one JSON document per line.

    Args:
        model: the model to sample from.
    """

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            on_writer_rank = mpu.is_pipeline_last_stage() \
                and mpu.get_tensor_model_parallel_rank() == 0
            if not on_writer_rank:
                # Keep consuming the generator so all ranks stay in step.
                continue
            f.write('{}\n'.format(json.dumps(datum)))
389

390

Mohammad's avatar
Mohammad committed
391
392
def pad_batch(batch, pad_id, args):
    """Pad every token list in `batch` up to `args.seq_length`.

    The lists are extended in place with `pad_id`; lists that are already
    `args.seq_length` long or longer are left untouched (no truncation).

    Args:
        batch: list of token-id lists.
        pad_id: token id used for padding.
        args: object providing `seq_length`.

    Returns:
        A `(batch, context_lengths)` tuple: the same `batch` object and the
        length each list had before padding.
    """

    context_lengths = []
    for tokens in batch:
        original_length = len(tokens)
        context_lengths.append(original_length)
        shortfall = args.seq_length - original_length
        if shortfall > 0:
            tokens.extend([pad_id] * shortfall)
    return batch, context_lengths

401
402

def get_token_stream(model, context_tokens):
    """Yield the growing token sequences while the model generates.

    The contexts are padded to `args.seq_length` with the tokenizer's
    end-of-document token, moved to the GPU and broadcast within the
    tensor-model-parallel group. Generation is then delegated to
    `sample_sequence_batch`.

    Args:
        model: the model to sample from.
        context_tokens: list of token-id lists; padded in place.

    Yields:
        `(tokens, lengths)` per generation step, where `tokens` is cut to
        the current context length; `(None, None)` whenever
        `sample_sequence_batch` produced no tokens on this rank.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    padded_tokens, unpadded_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)
    context_tokens_tensor = torch.cuda.LongTensor(padded_tokens)
    context_length_tensor = torch.cuda.LongTensor(unpadded_lengths)

    # Lengths first, then tokens: same order as every other rank expects.
    tp_src = mpu.get_tensor_model_parallel_src_rank()
    tp_group = mpu.get_tensor_model_parallel_group()
    for buffer in (context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(buffer, tp_src, group=tp_group)

    context_length = context_length_tensor.min().item()
    _, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is None:
            yield None, None
        else:
            yield tokens[:, :context_length], lengths
432
433
434


def switch(val1, val2, boolean):
    """Blend `val1` and `val2` element-wise using `boolean` as selector.

    `boolean` is cast to the type of `val1`; where it is 1 the result takes
    `val2`, where it is 0 the result takes `val1`.
    """

    selector = boolean.type_as(val1)
    keep_first = 1 - selector
    return keep_first * val1 + selector * val2
438

439

440
441
442
443
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run one forward pass, exchanging activations via p2p communication.

    The input is obtained with `recv_forward`, handed to the unwrapped
    model through `set_input_tensor`, and the output is passed on with
    `send_forward`.

    Args:
        model: the (possibly wrapped) model to call.
        tokens: token ids; `tokens.shape[1]` is used as the temporary
            sequence length.
        position_ids: forwarded to the model.
        attention_mask: forwarded to the model.
        tokentype_ids: forwarded to the model.
        layer_past: forwarded to the model.
        get_key_value: when truthy, the model output is unpacked into
            `(output, layer_past)` and both are returned.
        forward_method_parallel_output: forwarded to the model.

    Returns:
        `output_tensor`, or `(output_tensor, layer_past)` when
        `get_key_value` is truthy.
    """

    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(tokens, position_ids, attention_mask,
                              tokentype_ids=tokentype_ids,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)

        if get_key_value:
            output_tensor, layer_past = output_tensor

        send_forward(output_tensor)
    finally:
        # Restore the global sequence length even if the forward pass or
        # the communication raised, so later callers see the real value.
        args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


473
474
475
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Generate tokens step by step, yielding the state after each step.

    Generation starts at the shortest context length in the batch and runs
    until `maxlen` is passed or every sequence has produced the end token.
    The end token is `args.eos_id` when that attribute exists, otherwise
    the tokenizer's end-of-document token.

    The last pipeline stage picks the next token (argmax when
    `args.greedy`, otherwise temperature / top-k / top-p sampling),
    broadcasts it over the embedding group and broadcasts the `done` flag
    over the pipeline-model-parallel group.

    Args:
        model: the model to sample from; it is put into eval mode.
        context_tokens: token tensor that is filled in place.
        context_lengths: per-sequence context lengths.
        attention_mask: forwarded to `forward_step`.
        position_ids: position ids matching `context_tokens`.
        maxlen: last position to generate; when None it is the smaller of
            `args.seq_length - 1` and the shortest context length plus
            `args.out_seq_length`.
        type_ids: optional token type ids.

    Yields:
        `(tokens, lengths)` on the last pipeline stage, `(tokens, None)` on
        the first pipeline stage when it is not also the last, and
        `(None, None)` on every other stage.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that
        # passes eos_id as an argument and needs termination when that id
        # is found.
        eos_id = args.eos_id if hasattr(args, 'eos_id') else tokenizer.eod

        step = 0
        first_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = min(args.seq_length - 1,
                         first_context_length + args.out_seq_length)

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= maxlen:
            if args.recompute:
                # Recompute everything from the full token tensor.
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                # Feed the whole context on the first step, afterwards only
                # the newest position together with the cached `layer_past`.
                types2use = None
                if step == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    newest = context_length - 1
                    tokens2use = tokens[:, newest].view(batch_size, -1)
                    positions2use = position_ids[:, newest].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, newest].view(batch_size, -1)
                output, layer_past = forward_step(
                    model, tokens2use,
                    positions2use,
                    attention_mask,
                    layer_past=layer_past,
                    get_key_value=True,
                    tokentype_ids=types2use,
                    forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(probs, num_samples=1).view(-1)

                # Sequences whose context is not used up yet keep their
                # given token instead of the sampled one.
                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                last_rank = mpu.get_pipeline_model_parallel_last_rank()
                torch.distributed.broadcast(new_tokens, last_rank,
                                            mpu.get_embedding_group())

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                torch.distributed.broadcast(
                    done, last_rank,
                    mpu.get_pipeline_model_parallel_group())
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    # Receive the token chosen by the last stage.
                    last_rank = mpu.get_pipeline_model_parallel_last_rank()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, last_rank,
                                                mpu.get_embedding_group())
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the termination flag from the last stage.
                done = torch.cuda.ByteTensor([0])
                torch.distributed.broadcast(
                    done,
                    mpu.get_pipeline_model_parallel_last_rank(),
                    mpu.get_pipeline_model_parallel_group())

            context_length += 1
            step += 1
            if done:
                break