text_generation_utils.py 23.6 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Utilities for generating text."""
17

18
import copy
Mohammad's avatar
Mohammad committed
19
20
21
22
import json
import os
import time

23
24
25
import torch
import torch.nn.functional as F

Mohammad's avatar
Mohammad committed
26
27
28
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
31

32
33
34
35
36
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module

37
38
39
40
def get_batch(context_tokens):
    """Turn a flat context-token tensor into model inputs.

    The tokens are reshaped to ``(args.micro_batch_size, -1)``, made
    contiguous and moved to the GPU, then passed to
    ``get_ltor_masks_and_position_ids`` together with the tokenizer's EOD id
    and the ``reset_position_ids`` / ``reset_attention_mask`` /
    ``eod_mask_loss`` flags from ``args``.

    Returns:
        ``(tokens, attention_mask, position_ids)``; the second value returned
        by ``get_ltor_masks_and_position_ids`` is discarded.
    """
    args = get_args()
    eod_token = get_tokenizer().eod

    # Reshape into one row per micro-batch sample and move to the GPU.
    batch_view = context_tokens.view(args.micro_batch_size, -1)
    tokens = batch_view.contiguous().cuda()

    # Attention mask and position ids for left-to-right generation.
    attention_mask, _unused, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        eod_token,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )
    return tokens, attention_mask, position_ids

54

55
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to ``logits``.

    Mostly taken from the Hugging Face conversational AI example:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Filtering works along the last dimension, so ``logits`` may be a single
    score vector or a batch of them. ``logits`` is modified in place and is
    also returned.

    Args:
        logits: tensor of unnormalized scores whose last dim is the vocabulary.
        top_k: if > 0, keep only the ``top_k`` highest-scoring entries. Values
            larger than the vocabulary size keep everything.
        top_p: if > 0.0, keep the smallest set of highest-scoring entries whose
            cumulative softmax probability exceeds ``top_p``. The top entry is
            always kept.
        filter_value: value written into the removed entries.

    Returns:
        The filtered ``logits`` tensor (the same object that was passed in).
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the size of the dimension.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens with a score less than the last token of the top-k.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        # Sort along the vocabulary dimension and accumulate probabilities.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is
        # kept as well (and the top token is never removed).
        sorted_indices_to_remove[..., 1:] = \
            sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order with a
        # single scatter. This replaces a per-row Python loop that only
        # handled 2-D input.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits


88
def generate_samples_input_from_file(model):
    """Generate a completion for each prompt line of ``args.sample_input_file``.

    Only the first pipeline stage on tensor-parallel rank 0 reads the input
    file, prints progress and writes ``Context:`` / ``Megatron-LM:`` records
    to ``args.sample_output_file``. That path defaults to the input path plus
    ``.out``.

    Every rank must call this function. Each iteration all-reduces the
    prompt metadata (terminate flag, raw text length, context length) over
    the model-parallel group, so all ranks agree on when to stop and how
    long the context is.

    The run ends after the last input line has been processed, or earlier
    at the first line containing the substring ``"stop"``. Prompts whose
    token count is at least half of ``args.seq_length`` are reported and
    skipped.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    fname_out = None
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        with open(args.sample_input_file, "r") as fname:
            all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w+")

    model.eval()
    try:
        with torch.no_grad():
            while True:
                terminate_runs = 0
                raw_text_len = 0
                # Must be bound on every rank before the all_reduce below. The
                # "stop" branch never assigns it, which used to raise
                # UnboundLocalError when the first line read was a stop line.
                context_length = 0

                if mpu.is_pipeline_first_stage() \
                   and mpu.get_tensor_model_parallel_rank() == 0:
                    # Request termination once every line has been consumed.
                    # This also covers an empty file. The previous code
                    # replaced the last line with "stop", so it never ran.
                    if input_pos >= input_count:
                        raw_text = "stop"
                    else:
                        raw_text = all_raw_text[input_pos]
                        input_pos += 1
                    raw_text_len = len(raw_text)

                    if "stop" in raw_text:
                        terminate_runs = 1
                    else:
                        context_tokens = tokenizer.tokenize(raw_text)
                        context_length = len(context_tokens)

                        if context_length >= (args.seq_length // 2):
                            print("\nContext length", context_length,
                                  "\nPlease give smaller context (half of the "
                                  "sequence length)!", flush=True)
                            continue
                else:
                    context_tokens = tokenizer.tokenize("EMPTY TEXT")

                # Ranks other than the reader contribute zeros, so the sum
                # equals the reader's values.
                input_info = [terminate_runs, raw_text_len, context_length]
                input_info_tensor = torch.cuda.LongTensor(input_info)
                torch.distributed.all_reduce(input_info_tensor,
                                             group=mpu.get_model_parallel_group())
                terminate_runs = input_info_tensor[0].item()
                raw_text_len = input_info_tensor[1].item()
                context_length = input_info_tensor[2].item()

                if terminate_runs == 1:
                    return

                # For pipeline parallel we send context tokens to other stages
                # so they get the lengths correct.
                if mpu.get_tensor_model_parallel_rank() == 0 \
                   and args.pipeline_model_parallel_size > 1:
                    src = mpu.get_pipeline_model_parallel_first_rank()
                    group = mpu.get_pipeline_model_parallel_group()
                    if mpu.is_pipeline_first_stage():
                        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                        torch.distributed.broadcast(context_tokens_tensor, src, group)
                    else:
                        context_tokens_tensor = torch.empty(context_length,
                                                            dtype=torch.int64,
                                                            device=torch.device("cuda"))
                        torch.distributed.broadcast(context_tokens_tensor, src, group)
                        context_tokens = context_tokens_tensor.cpu().numpy().tolist()

                # Drain the stream on every rank; only the final output is kept.
                decode_tokens = None
                for decode_tokens in get_token_stream(model, [context_tokens]):
                    pass

                if mpu.get_tensor_model_parallel_rank() == 0 \
                   and mpu.is_pipeline_first_stage():
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

                    fname_out.write("\nContext:")
                    fname_out.write(raw_text)

                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                    fname_out.write("\n\nMegatron-LM:")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")
    finally:
        # Close the output file on every exit path so buffered text is flushed.
        if fname_out is not None:
            fname_out.close()
Mohammad's avatar
Mohammad committed
192

Mostofa Patwary's avatar
Mostofa Patwary committed
193
194
195
196
# We added this function to support task evaluation such as SQuAD
# and DROP in the https://github.com/EleutherAI/lm-evaluation-harness
# codebase. The lm-evaluation-harness code can now call this function
# similarly to their current generate function call used for GPT-style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate a continuation of ``context`` for lm-evaluation-harness.

    Args:
        model: model used for generation.
        context: prompt string.
        max_gen_length: maximum number of new tokens to sample.
        eos_token_id: token id that stops generation early once every row
            has produced it.

    Returns:
        The detokenized output with the first ``len(context)`` characters
        removed. Returns ``''`` when ``max_gen_length <= 0``, and ``None`` on
        ranks that receive no tokens from the stream (e.g. middle pipeline
        stages).

    ``args.out_seq_length`` and ``args.eos_id`` are overridden only for the
    duration of this call and restored afterwards.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    if max_gen_length <= 0:
        return ''

    had_eos_id = hasattr(args, 'eos_id')
    saved_eos_id = getattr(args, 'eos_id', None)
    saved_out_seq_length = args.out_seq_length

    # sample_sequence_batch already adds the context length to
    # out_seq_length and samples positions context..context+out_seq_length
    # inclusive. Setting max_gen_length - 1 therefore yields exactly
    # max_gen_length new tokens (still capped by seq_length - 1). The old
    # value, max_gen_length + len(context_tokens), counted the context twice.
    args.out_seq_length = max_gen_length - 1
    args.eos_id = eos_token_id

    decode_tokens = None
    try:
        with torch.no_grad():
            # Drive the stream to completion rather than breaking out early.
            # On non-last pipeline stages the sampler issues a broadcast after
            # each yield, so every rank must resume it until it finishes.
            for decode_tokens in get_token_stream(model, [context_tokens]):
                pass
    finally:
        args.out_seq_length = saved_out_seq_length
        if had_eos_id:
            args.eos_id = saved_eos_id
        else:
            del args.eos_id

    if decode_tokens is None:
        return None
    tokens, _ = decode_tokens
    if tokens is None:
        return None
    decode_tokens = tokens[0].cpu().numpy().tolist()
    trim_decode_tokens = tokenizer.detokenize(
        decode_tokens)[raw_text_len:]

    return trim_decode_tokens

224

225
def generate_samples_interactive(model, print_frequency=24):
    """Interactively read prompts and stream the generated continuations.

    The first pipeline stage on tensor-parallel rank 0 reads prompts from
    stdin. Entering text that contains ``"stop"`` ends the session, and
    prompts whose token count is at least half of ``args.seq_length`` are
    rejected. Every rank must call this function: prompt metadata is
    all-reduced over the model-parallel group each round.

    Args:
        model: model used for generation.
        print_frequency: number of sampling steps between refreshes of the
            partial completion on screen.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0
            # Must be bound on every rank before the all_reduce below. The
            # "stop" branch never assigns it, which used to raise
            # UnboundLocalError when "stop" was the very first prompt.
            context_length = 0

            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                os.system('clear')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text_len = len(raw_text)

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_tokens = tokenizer.tokenize(raw_text)
                    context_length = len(context_tokens)

                    if context_length >= (args.seq_length // 2):
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                        continue
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            # Ranks other than the reader contribute zeros, so the sum
            # equals the reader's values.
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.cuda.LongTensor(input_info)
            torch.distributed.all_reduce(input_info_tensor,
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct.
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                src = mpu.get_pipeline_model_parallel_first_rank()
                group = mpu.get_pipeline_model_parallel_group()
                if mpu.is_pipeline_first_stage():
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                else:
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            is_printer = mpu.is_pipeline_first_stage() \
                and mpu.get_tensor_model_parallel_rank() == 0

            # step_output keeps the raw (tokens, lengths) pair from the latest
            # step, so the final print below can always decode it.
            for counter, step_output in enumerate(token_stream):
                if counter % print_frequency != 0 or not is_printer:
                    continue

                os.system('clear')
                print("\nContext:", raw_text, flush=True)
                step_tokens, _ = step_output
                partial_text = tokenizer.detokenize(
                    step_tokens[0].cpu().numpy().tolist())[raw_text_len:]
                print("\nMegatron-LM:", partial_text, flush=True)

            if is_printer:
                os.system('clear')
                print("\nContext:", raw_text, flush=True)

                final_tokens, _ = step_output
                trim_decode_tokens = tokenizer.detokenize(
                    final_tokens[0].cpu().numpy().tolist())[raw_text_len:]
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")
Mohammad's avatar
Mohammad committed
332

333

334
335

def generate_samples_unconditional(model):
    """Yield ``args.num_samples`` samples generated from an EOD-only prompt.

    Every row of each micro-batch starts from a single ``tokenizer.eod``
    token. On the last pipeline stage with tensor-parallel rank 0, each
    yielded item is a dict with keys ``'text'``, ``'length'`` and
    ``'finished'``. Every other rank yields ``None`` per sample, so all ranks
    iterate the same number of times.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.micro_batch_size)]
    ctr = 0
    # The timer and batch counter persist across batches, so each log line
    # reports the mean time over the batches since the previous log. The
    # timer used to be reset at the start of every batch, which divided a
    # single batch's time by log_interval.
    batch_idx = 0
    start_time = time.time()
    while True:
        # Pass a fresh copy so the shared EOD prompt lists are never modified
        # by padding.
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass

        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if batch_idx % args.log_interval == 0:
                print('Avg s/batch:',
                      (time.time() - start_time) / min(args.log_interval, batch_idx + 1))
                start_time = time.time()
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                # Drop the leading EOD prompt token and the final token.
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break

        batch_idx += 1
        if ctr >= num_samples:
            break

377

Mohammad's avatar
Mohammad committed
378
def generate_and_write_samples_unconditional(model):
    """Write unconditional samples to ``args.genfile``, one JSON object per line.

    Every rank iterates the sample generator, which drives
    ``get_token_stream`` and its broadcasts. Only the last pipeline stage on
    tensor-parallel rank 0 writes records; every rank opens ``args.genfile``
    for writing.
    """
    args = get_args()
    assert args.genfile is not None

    with open(args.genfile, 'w') as out_file:
        for datum in generate_samples_unconditional(model):
            is_writer = (mpu.is_pipeline_last_stage()
                         and mpu.get_tensor_model_parallel_rank() == 0)
            if not is_writer:
                continue
            out_file.write(json.dumps(datum) + '\n')
387

388

Mohammad's avatar
Mohammad committed
389
390
def pad_batch(batch, pad_id, args):
    """Right-pad each token list in ``batch`` to ``args.seq_length``.

    The input lists are left untouched. The previous implementation extended
    them in place, which forced callers to deep-copy shared prompts.

    Args:
        batch: iterable of token-id lists.
        pad_id: token id appended as padding.
        args: object providing ``seq_length``.

    Returns:
        ``(padded_batch, context_lengths)``. ``padded_batch`` is a new list of
        new lists, and ``context_lengths[i]`` is the unpadded length of the
        i-th input. Sequences already at least ``seq_length`` long are copied
        unchanged, not truncated.
    """
    padded_batch = []
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        padding = [pad_id] * max(0, args.seq_length - context_length)
        padded_batch.append(list(tokens) + padding)
        context_lengths.append(context_length)
    return padded_batch, context_lengths

399
400

def get_token_stream(model, context_tokens):
    """Stream the growing token prefix produced by ``sample_sequence_batch``.

    ``context_tokens`` (a list of token-id lists) is padded with
    ``tokenizer.eod``. The padded tokens and their lengths are then broadcast
    from the tensor-parallel source rank, so every rank in that group samples
    from the same inputs.

    Yields:
        ``(tokens, lengths)``. ``tokens`` is truncated to the current
        generation position. Yields ``(None, None)`` whenever the sampler
        produces no tokens on this rank.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    padded_tokens, unpadded_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)
    tokens_tensor = torch.cuda.LongTensor(padded_tokens)
    lengths_tensor = torch.cuda.LongTensor(unpadded_lengths)

    src_rank = mpu.get_tensor_model_parallel_src_rank()
    tp_group = mpu.get_tensor_model_parallel_group()
    torch.distributed.broadcast(lengths_tensor, src_rank, group=tp_group)
    torch.distributed.broadcast(tokens_tensor, src_rank, group=tp_group)

    # Generation begins right after the shortest context in the batch.
    prefix_end = lengths_tensor.min().item()
    _, attention_mask, position_ids = get_batch(tokens_tensor)

    sampler = sample_sequence_batch(model, tokens_tensor, lengths_tensor,
                                    attention_mask, position_ids)
    for sampled_tokens, sampled_lengths in sampler:
        prefix_end += 1
        if sampled_tokens is None:
            yield None, None
        else:
            yield sampled_tokens[:, :prefix_end], sampled_lengths
430
431
432


def switch(val1, val2, boolean):
    """Elementwise select: ``val2`` where ``boolean`` is true, else ``val1``.

    Uses ``torch.where`` instead of the arithmetic blend
    ``(1 - b) * val1 + b * val2``. That blend produced NaN whenever the
    unselected value was infinite (``0 * inf``) and could overflow for large
    integers.

    Args:
        val1: values kept where ``boolean`` is false.
        val2: values taken where ``boolean`` is true.
        boolean: mask tensor; any nonzero entry counts as true.

    Returns:
        A new tensor broadcast from the three inputs.
    """
    return torch.where(boolean.bool(), val2, val1)
436

437

438
439
440
441
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run one pipeline-stage forward pass during generation.

    Receives the previous stage's output with ``recv_forward``, sets it as
    the unwrapped model's input tensor, runs ``model`` and sends the result
    on with ``send_forward``.

    ``args.seq_length`` is temporarily set to ``tokens.shape[1]`` for the
    p2p calls and is restored afterwards, even if an exception is raised.

    Returns:
        ``(output_tensor, layer_past)`` when ``get_key_value`` is truthy,
        otherwise ``output_tensor``.
    """
    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size.
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(tokens, position_ids, attention_mask,
                              tokentype_ids=tokentype_ids,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)

        if get_key_value:
            output_tensor, layer_past = output_tensor

        send_forward(output_tensor)
    finally:
        # Never leave the global seq_length pointing at the generation
        # length: later code reads args.seq_length.
        args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


471
472
473
def _sample_next_tokens(logits, args):
    """Choose one token id per row from ``logits`` of shape (batch, vocab).

    Takes the argmax when ``args.greedy`` is set. Otherwise scales by
    ``args.temperature``, applies top-k/top-p filtering and draws a sample
    from the resulting softmax distribution.
    """
    if args.greedy:
        return torch.argmax(logits, dim=-1).view(-1)
    scaled = logits.float()
    scaled /= args.temperature
    scaled = top_k_logits(scaled, top_k=args.top_k, top_p=args.top_p)
    probs = F.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Autoregressively fill ``context_tokens`` one position per step.

    Generation starts at the shortest context length. At each position, rows
    whose context has already ended (``context_lengths <= position``) get
    the sampled token; longer rows keep their own context token. Tokens are
    written into ``context_tokens`` in place.

    The last pipeline stage broadcasts each step's tokens over the embedding
    group. It also broadcasts a "done" flag over the pipeline group, so all
    stages stop together once every row has produced the end token.

    Args:
        model: model passed to ``forward_step``.
        context_tokens: (batch, seq) LongTensor of padded token ids.
        context_lengths: (batch,) LongTensor of unpadded context lengths.
        attention_mask: attention mask passed to every forward step.
        position_ids: (batch, seq) position ids, sliced like the tokens.
        maxlen: last position to fill. Defaults to
            ``min(args.seq_length - 1, min_context + args.out_seq_length)``.
        type_ids: optional token-type ids, sliced like the tokens.

    Yields:
        On the last stage, ``(tokens, lengths)``, where ``lengths[i]`` is the
        position at which row ``i`` produced the end token (``maxlen`` if it
        never did). On other first stages, ``(tokens, None)``. Elsewhere,
        ``(None, None)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        position = context_lengths.min().item()

        # generate_samples_eval may store a caller-chosen end token in
        # args.eos_id; otherwise stop on the tokenizer's EOD token.
        eos_id = args.eos_id if hasattr(args, 'eos_id') else tokenizer.eod

        step = 0
        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = min(args.seq_length - 1, position + args.out_seq_length)

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while position <= maxlen:
            if args.recompute:
                # Re-run the whole sequence every step.
                output = forward_step(model, tokens, position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, position - 1, :]
            else:
                if step == 0:
                    # First step: feed the whole shared prefix.
                    tokens2use = tokens[:, :position]
                    positions2use = position_ids[:, :position]
                    types2use = (type_ids[:, :position]
                                 if type_ids is not None else None)
                else:
                    # Later steps: feed only the newest token. Earlier context
                    # is carried by layer_past (get_key_value=True).
                    newest = position - 1
                    tokens2use = tokens[:, newest].view(batch_size, -1)
                    positions2use = position_ids[:, newest].view(batch_size, -1)
                    types2use = (type_ids[:, newest].view(batch_size, -1)
                                 if type_ids is not None else None)
                output, layer_past = forward_step(
                    model, tokens2use, positions2use, attention_mask,
                    layer_past=layer_past,
                    get_key_value=True,
                    tokentype_ids=types2use,
                    forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                prev = _sample_next_tokens(logits, args)

                # Only rows whose context has ended take the sampled token.
                started = context_lengths <= position
                new_tokens = switch(
                    tokens[:, position].view(-1), prev, started)
                tokens[:, position] = new_tokens
                last_rank = mpu.get_pipeline_model_parallel_last_rank()
                torch.distributed.broadcast(new_tokens, last_rank,
                                            mpu.get_embedding_group())

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = position
                is_done = is_done | done_token

                done = torch.all(is_done)
                torch.distributed.broadcast(
                    done, last_rank, mpu.get_pipeline_model_parallel_group())
                yield tokens, lengths

            else:
                last_rank = mpu.get_pipeline_model_parallel_last_rank()
                if mpu.is_pipeline_first_stage():
                    new_tokens = torch.empty_like(tokens[:, position])
                    torch.distributed.broadcast(new_tokens, last_rank,
                                                mpu.get_embedding_group())
                    tokens[:, position] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the last stage's "done" flag; this must run after
                # the consumer resumes the generator.
                done = torch.cuda.ByteTensor([0])
                torch.distributed.broadcast(
                    done, last_rank, mpu.get_pipeline_model_parallel_group())

            position += 1
            step += 1
            if done:
                break