generation.py 20.7 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
mshoeybi's avatar
mshoeybi committed
2
3
4
5
6
7

"""Generation utilities."""

import torch
import torch.nn.functional as F

xingjinliang's avatar
xingjinliang committed
8
from megatron.training import get_args, get_tokenizer
9
from megatron.core import mpu
xingjinliang's avatar
xingjinliang committed
10
from megatron.training.utils import get_ltor_masks_and_position_ids
mshoeybi's avatar
mshoeybi committed
11
12
from .communication import (
    copy_from_last_to_first_pipeline_stage,
mshoeybi's avatar
working  
mshoeybi committed
13
14
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
mshoeybi's avatar
mshoeybi committed
15
from .forward_step import ForwardStep
mshoeybi's avatar
mshoeybi committed
16
from .sampling import sample
Peng Xu's avatar
Peng Xu committed
17
from .beam_utils import BeamHypotheses
mshoeybi's avatar
mshoeybi committed
18

xingjinliang's avatar
xingjinliang committed
19
20
21
22
# Number of highest log-probability entries kept per position by the scoring
# path; it is the ``k`` handed to ``torch.topk`` and the last dimension of the
# pre-allocated top-k buffers.
MAX_TOPK_LOGPROBS: int = 5
# Value returned in the top-k slot by paths that do not compute top-k log
# probabilities.
NO_TOPK_LOGPROBS: None = None

def score_and_return_on_first_stage(model, tokens: torch.Tensor, lengths: torch.Tensor):
    """Score prompt tokens with a single forward pass (no generation).

    Args:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max_prompt_length]
        lengths: original prompt length, size: [b]
    Note: Outside of model, other parameters only need to be available on
          rank 0.

    Returns:
        tokens: the input ``tokens``, unchanged.
        lengths: the input ``lengths``, unchanged.
        output_log_probs: log probability of each token given its prefix,
            i.e. of ``tokens[:, 1:]``. size: [b, max_prompt_length - 1],
            float32.
        logprobs_topk: ``torch.return_types.topk`` of (values, indices),
            each of size [b, max_prompt_length, MAX_TOPK_LOGPROBS].
        The results are meant for the first pipeline stage (see the
        function name). Other stages may get None from the broadcast
        helpers; confirm against ``.communication``.

    Raises:
        ValueError: if the prompt is longer than
            ``args.max_position_embeddings`` or ``batch * length`` exceeds
            ``args.max_tokens_to_oom``.
    """

    args = get_args()

    batch_size = tokens.size(0)
    max_prompt_length = lengths.max().item()
    assert max_prompt_length == tokens.size(1)

    if max_prompt_length > args.max_position_embeddings:
        raise ValueError(
            f"Length of prompt + tokens_to_generate longer than allowed {max_prompt_length} > {args.max_position_embeddings}"
        )

    if max_prompt_length * batch_size > args.max_tokens_to_oom:
        raise ValueError(
            f"Too many tokens.  {max_prompt_length*batch_size} > {args.max_tokens_to_oom}"
        )

    # forward step.
    forward_step = ForwardStep(model, batch_size, args.inference_max_seq_length)

    # ===================
    # Pre-allocate memory
    # ===================

    # The broadcast calls at the end of this function are given exactly these
    # sizes and dtypes, so the tensors produced on the last stage must match.
    output_log_probs = None
    output_topk_log_probs, output_topk_log_indices = None, None
    output_log_probs_size = (batch_size, max_prompt_length - 1)
    output_topk_log_probs_size = (batch_size, max_prompt_length, MAX_TOPK_LOGPROBS)

    if mpu.is_pipeline_last_stage():
        device = torch.cuda.current_device()
        output_log_probs = torch.empty(
            output_log_probs_size, dtype=torch.float32, device=device
        )
        output_topk_log_probs = torch.empty(
            output_topk_log_probs_size, dtype=torch.float32, device=device
        )
        output_topk_log_indices = torch.empty(
            output_topk_log_probs_size, dtype=torch.int64, device=device
        )

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)

        # logits will be meaningful only in the last pipeline stage.
        logits = forward_step(tokens, position_ids, attention_mask)

        if mpu.is_pipeline_last_stage():
            # Always the last stage should have an output.
            assert logits is not None
            log_probs = F.log_softmax(logits, dim=2)

            # Logits at position i score the token at position i + 1, so
            # gather the log probabilities of tokens[:, 1:].
            indices = torch.unsqueeze(tokens[:, 1:], 2)
            # Write into the pre-allocated float32 buffer instead of
            # rebinding the name. That keeps the broadcast tensor at its
            # declared dtype regardless of the logits' dtype.
            output_log_probs.copy_(torch.gather(log_probs, 2, indices).squeeze(2))
            torch.topk(
                log_probs,
                MAX_TOPK_LOGPROBS,
                dim=2,
                out=(output_topk_log_probs, output_topk_log_indices),
            )

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
    output_topk_log_probs = broadcast_from_last_to_first_pipeline_stage(
        output_topk_log_probs_size, torch.float32, output_topk_log_probs
    )
    output_topk_log_indices = broadcast_from_last_to_first_pipeline_stage(
        output_topk_log_probs_size, torch.int64, output_topk_log_indices
    )
    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
        output_log_probs_size, torch.float32, output_log_probs
    )

    logprobs_topk = torch.return_types.topk((output_topk_log_probs, output_topk_log_indices))
    return tokens, lengths, output_log_probs, logprobs_topk
mshoeybi's avatar
mshoeybi committed
114

mshoeybi's avatar
working  
mshoeybi committed
115
def generate_tokens_probs_and_return_on_first_stage(
        model, forward_step, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False,
        prevent_newline_after_colon=True
        ):
    """Main token generation function.

    Args:
        model: no interleaving is supported.
        forward_step (ForwardStep): Class for running the model forward step.
            It is instantiated here as
            ``forward_step(model, batch_size, args.inference_max_seq_length)``.
        tokens: prompt tokens extended to be of size [b, max-sequence-length].
            Sampled tokens are written into this tensor in place on the last
            pipeline stage.
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit (before temperature/top-k/top-p), but
            after the newline-after-colon mask has been applied.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        top_p_decay: if > 0 (and top_p > 0), top_p is multiplied by this
            factor after every generation step.
        top_p_bound: if > 0, lower bound applied to the decayed top_p.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination if
            all the sequences have reached this token.
        stop_on_double_eol: if True, a sequence finishes on token id 628 or on
            two consecutive 198 tokens (presumably GPT-2 ids for a blank line
            and a newline; these ids are tokenizer dependent).
        stop_on_eol: if True, a sequence finishes on token id 628 or 198.
        prevent_newline_after_colon: if True, a newline token cannot be
            generated right after a ":" token.
    Note: Outside of model, other parameters only need to be available on
          rank 0.

    Returns: Note that the size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
            (None unless ``return_output_log_probs`` is True).
        NO_TOPK_LOGPROBS: this path never computes top-k log probabilities.

    Raises:
        ValueError: if the sequence is longer than
            ``args.max_position_embeddings``, if ``batch * length`` exceeds
            ``args.max_tokens_to_oom``, or if every prompt already fills
            ``tokens`` so there is nothing to generate.
        AttributeError: if no termination token id is available.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)

    if max_sequence_length > args.max_position_embeddings:
        raise ValueError("Length of prompt + tokens_to_generate longer than allowed")

    if max_sequence_length * batch_size > args.max_tokens_to_oom:
        raise ValueError(
            f"Too many tokens.  {max_sequence_length * batch_size} is greater than {args.max_tokens_to_oom}"
        )

    # The generation loop covers [min_prompt_length, max_sequence_length).
    # If that range is empty, the post-loop truncation would reference an
    # unbound `context_length`, so fail early with a clear message instead.
    if min_prompt_length >= max_sequence_length:
        raise ValueError(
            "No tokens to generate: every prompt already fills the "
            f"{max_sequence_length}-token buffer"
        )

    # forward step.
    forward_step = forward_step(model, batch_size, args.inference_max_seq_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    elif hasattr(tokenizer, 'eod'):
        termination_id = tokenizer.eod
    elif hasattr(tokenizer, 'eos_id'):
        termination_id = tokenizer.eos_id
    else:
        raise AttributeError('No eod token found in tokenizer or args')

    # Token ids used by the newline-after-colon mask. They are looked up once
    # here rather than re-tokenized on every generation step, and only on the
    # last stage, which is the only stage that applies the mask.
    colon_id = newline_id = None
    if prevent_newline_after_colon and mpu.is_pipeline_last_stage():
        colon_id = tokenizer.tokenize(':')[0]
        newline_id = tokenizer.tokenize('\n')[0]

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequence including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        # Default length for sequences that never hit a stop condition.
        generated_sequence_lengths = torch.ones(
                batch_size, dtype=torch.int64,
                device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output. Checked before
                # the in-place masking below indexes into `logits`.
                assert logits is not None

                if prevent_newline_after_colon:
                    # Disable a newline right after ":". This modifies
                    # `logits` in place, so it also affects the log probs
                    # computed below.
                    logits[tokens2use[:, -1] == colon_id, -1, newline_id] = -1e10

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)
                if top_p > 0.0 and top_p_decay > 0.0:
                    top_p = top_p * top_p_decay
                    if top_p_bound > 0.0:
                        top_p = max(top_p, top_p_bound)

                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens.
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[:, (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:, prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                # TODO(rprenger) These stopping methods are tokenizer dependent
                # instead tokenization should be in the inference loop so stop sequences can be used
                if stop_on_double_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_two_eols
                elif stop_on_eol:
                    hit_double_eol = (new_sample == 628).byte() & started.byte()
                    hit_eol = (new_sample == 198).byte() & started.byte()
                    done_token = hit_double_eol | hit_eol
                else:
                    done_token = (new_sample == termination_id).byte() & \
                        started.byte()

                # Record the length only the first time a sequence finishes.
                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # ===================================================
    # Update the length based on the max generated length.
    # ===================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs, NO_TOPK_LOGPROBS
mshoeybi's avatar
working  
mshoeybi committed
320

xingjinliang's avatar
xingjinliang committed
321
def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True):
    """Beam-search decoding for a single prompt.

    Args:
        model: model to run; passed to ``forward_step``.
        forward_step: ForwardStep class, instantiated here as
            ``forward_step(model, beam_size, final_sequence_length)``.
        tokens: prompt tokens padded to the maximum output length,
            size: [1, max-sequence-length]. Only a batch size of 1 is
            supported.
        lengths: prompt length as a one-element tensor (``.item()`` is used).
        beam_size: number of beams kept alive at every step.
        stop_token: token id that closes a hypothesis.
        num_return_gen: maximum number of hypotheses to return.
        length_penalty: forwarded to ``BeamHypotheses``.
        prevent_newline_after_colon: if True, a newline token cannot be
            generated right after a ":" token.

    Returns:
        tokens: the best hypotheses, highest score first. Each row is a
            full-width copy of a beam's token tensor.
            size: [min(num_return_gen, #hypotheses), max-sequence-length].
        scores: the matching scores as stored by ``BeamHypotheses``.
        The results are meant for the first pipeline stage (see the
        function name). Other stages may get None from the broadcast
        helpers; confirm against ``.communication``.

    Raises:
        ValueError: if the prompt leaves no room to generate tokens.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    assert batch_size == 1
    prompt_length = lengths.item()
    final_sequence_length = min(tokens.size(1), args.max_position_embeddings)

    # If the context is too big, this happens
    if prompt_length >= final_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = forward_step(model, beam_size, final_sequence_length)

    # Token ids used by the newline-after-colon mask. They are looked up once
    # instead of re-tokenized every step, and only on the last stage, which is
    # the only stage that applies the mask.
    colon_id = newline_id = None
    if prevent_newline_after_colon and mpu.is_pipeline_last_stage():
        colon_id = tokenizer.tokenize(':')[0]
        newline_id = tokenizer.tokenize('\n')[0]

    beam_hyp = BeamHypotheses(beam_size, length_penalty)
    best_batches = None
    done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
    # Running score of every beam, shape [beam_size, 1] so it broadcasts
    # against per-token log probs of shape [beam_size, vocab].
    scores = torch.zeros(beam_size,
                         dtype=torch.float32,
                         device=torch.cuda.current_device()).unsqueeze(1)
    scores_size_tensor, tokens_size_tensor = None, None

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        # Every beam starts as a copy of the single prompt.
        tokens = tokens.repeat(beam_size, 1)
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(prompt_length, final_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                if prevent_newline_after_colon:
                    # disable "\n" after ":"
                    logits[tokens2use[:, -1] == colon_id, -1, newline_id] = -1e10
                vocab_size = logits.size(2)
                log_probs = F.log_softmax(logits, dim=2)
                # Score of extending each beam by each token: [beam_size, vocab].
                new_scores = log_probs[:, -1, :] + scores

                if context_length == prompt_length:  # if this is the first one
                    # All beams are still identical copies of the prompt with
                    # score 0, so rank only beam 0 to avoid duplicate beams.
                    sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
                else:
                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

                # Map flat indices back to (beam, token). Integer floor
                # division is exact, unlike float division followed by trunc.
                top_indices = indices[:2 * beam_size]
                best_beam_ids = torch.div(top_indices, vocab_size, rounding_mode='floor')
                best_words = top_indices % vocab_size
                best_scores = sorted_scores[:2 * beam_size]

                next_beams = []
                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
                    zip(best_words, best_scores, best_beam_ids)
                ):
                    if token_id.item() == stop_token:
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        beam_hyp.add(
                            tokens[beam_id].clone(),
                            beam_score,
                            context_length + 1 - prompt_length
                        )
                    else:
                        # add next predicted token since it is not eos_token
                        next_beams.append((token_id, beam_score, beam_id))

                    if len(next_beams) == beam_size:
                        break

                if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                    done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())

                # Reorder beams to follow their selected parents, then append
                # the chosen tokens and update the running scores.
                best_batches = tokens.new([item[2] for item in next_beams])
                tokens = tokens[best_batches, :]
                tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
                scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)

            done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
            if done:
                break

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
                                                   tokens)

            # set inference key values to make it consistent with best beam index
            best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
            forward_step.inference_params.swap_key_value_dict(best_batches)

            # Update the context length for the next token generation.
            prev_context_length = context_length

        if mpu.is_pipeline_last_stage():
            # if cannot find stop token, add open beams to hyps
            if not done:
                for beam_id in range(beam_size):
                    beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)

            # rank based on scores
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
            num_return_gen = min(num_return_gen, len(sorted_hyps))
            top_hyps = sorted_hyps[:num_return_gen]
            scores = torch.stack([hyp[0] for hyp in top_hyps], dim=0)
            tokens = torch.stack([hyp[1] for hyp in top_hyps], dim=0)
            scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
            tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())

        scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
        tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)

        scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
        tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)

    return tokens, scores
mshoeybi's avatar
mshoeybi committed
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462


def _build_attention_mask_and_position_ids(tokens):
    """Return ``(attention_mask, position_ids)`` for left-to-right decoding.

    Thin wrapper around ``get_ltor_masks_and_position_ids``; the loss mask it
    also returns is discarded.
    """
    # No position/attention resets and no EOD loss masking are requested, so
    # the EOD token is never consulted and None is a safe placeholder.
    masks_and_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
    )
    attention_mask, _loss_mask, position_ids = masks_and_ids
    return attention_mask, position_ids