generation.py 19 KB
Newer Older
mshoeybi's avatar
mshoeybi committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generation utilities."""

import torch
import torch.nn.functional as F

mshoeybi's avatar
working  
mshoeybi committed
21
from megatron import get_args, get_tokenizer, mpu
mshoeybi's avatar
mshoeybi committed
22
23
24
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
mshoeybi's avatar
working  
mshoeybi committed
25
26
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
mshoeybi's avatar
mshoeybi committed
27
from .forward_step import ForwardStep
mshoeybi's avatar
mshoeybi committed
28
from .sampling import sample
Peng Xu's avatar
Peng Xu committed
29
from .beam_utils import BeamHypotheses
mshoeybi's avatar
mshoeybi committed
30

31
32
MAX_TOKENS_TO_OOM = 12000  # (rprenger) Perfect value depends on hardware and network

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def score_and_return_on_first_stage(model, tokens, lengths):
    """Score a batch of prompts without generating any new tokens.

    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max_prompt_length]
        lengths: original prompt length, size: [b]
    Note: Outside of model, other parameters only need to be available on
          rank 0.
    Raises:
        ValueError: if the longest prompt exceeds args.max_position_embeddings,
            or if batch_size * max_prompt_length exceeds MAX_TOKENS_TO_OOM.
    Outputs:
        tokens: the input tokens, returned unchanged.
        lengths: the input lengths, returned unchanged.
        output_log_probs: log probability of each token given its prefix,
            i.e. output_log_probs[:, i] is the log-prob of tokens[:, i + 1].
            size: [b, max_prompt_length - 1], dtype float32. Computed on the
            last pipeline stage and broadcast to the first one.
    """

    args = get_args()

    batch_size = tokens.size(0)
    max_prompt_length = lengths.max().item()
    assert max_prompt_length == tokens.size(1)

    # Fail fast instead of clamping. The forward pass below always runs over
    # the full prompt, so clamping only the expected output size would make
    # the broadcast size disagree with the tensor actually produced.
    if max_prompt_length > args.max_position_embeddings:
        raise ValueError("Length of prompt longer than allowed")

    # Same memory guard as the generation path.
    if max_prompt_length * batch_size > MAX_TOKENS_TO_OOM:
        raise ValueError("Too many tokens.  " + str(max_prompt_length*batch_size)+ " is greater than "+str(MAX_TOKENS_TO_OOM))

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_prompt_length)

    # Only the last pipeline stage computes log-probs; every other stage
    # passes None and receives the result from the broadcast below.
    output_log_probs = None
    output_log_probs_size = (batch_size, max_prompt_length - 1)

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)

        # logits will be meaningful only in the last pipeline stage.
        logits = forward_step(tokens, position_ids, attention_mask)

        if mpu.is_pipeline_last_stage():
            # Always the last stage should have an output.
            assert logits is not None
            log_probs = F.log_softmax(logits, dim=2)

            # Pick the tokens that we need to get the log probabilities for.
            # The logits at position i predict token i + 1, so shift by 1.
            indices = torch.unsqueeze(tokens[:, 1:], 2)
            # Cast so the tensor matches the float32 dtype declared to the
            # broadcast, whatever dtype the model emits.
            output_log_probs = torch.gather(
                log_probs, 2, indices).squeeze(2).float()

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
        output_log_probs_size, torch.float32, output_log_probs)

    return tokens, lengths, output_log_probs
mshoeybi's avatar
mshoeybi committed
96

mshoeybi's avatar
working  
mshoeybi committed
97
98
99
def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True,
        stop_on_double_eol=False,
        stop_on_eol=False
        ):
    """Main token generation function.

    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length].
            On the last pipeline stage, sampled tokens are written into this
            tensor in place.
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, stop early once every
            sequence has hit its stop condition (the termination id, or the
            EOL-based conditions selected below).
        stop_on_double_eol: a sequence stops on a double-EOL token or on two
            consecutive EOL tokens.
        stop_on_eol: a sequence stops on any EOL or double-EOL token.
    Note: Outside of model, other parameters only need to be available on
          rank 0.
    Raises:
        ValueError: if the sequence length exceeds
            args.max_position_embeddings, if batch_size * sequence length
            exceeds MAX_TOKENS_TO_OOM, or if no prompt leaves room to
            generate.
    Outputs: Note that the size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)

    if max_sequence_length > args.max_position_embeddings:
        raise ValueError("Length of prompt + tokens_to_generate longer than allowed")

    # Compare with `>` so the check agrees with its "is greater than" message.
    if max_sequence_length * batch_size > MAX_TOKENS_TO_OOM:
        raise ValueError("Too many tokens.  " + str(max_sequence_length*batch_size)+ " is greater than "+str(MAX_TOKENS_TO_OOM))

    # With no room to generate, the loop below would never bind
    # `context_length` and the trimming step after it would raise NameError.
    if min_prompt_length >= max_sequence_length:
        raise ValueError("No room to generate tokens: prompt length must be "
                         "smaller than the sequence length")

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # Token ids used by the EOL-based stop criteria. They are hard-coded and
    # are presumably the GPT-2 BPE ids for "\n\n" (628) and "\n" (198).
    # TODO(rprenger) These stopping methods are tokenizer dependent; instead,
    # tokenization should be in the inference loop so stop sequences can be
    # used.
    double_eol_id = 628
    eol_id = 198

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequence including prompts. Defaults to the full
    # length for sequences that never hit a stop condition.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
                batch_size, dtype=torch.int64,
                device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # Only feed the positions not yet seen by the model; earlier
            # positions are presumably served from ForwardStep's cached state.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.
                assert logits is not None

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # A row whose prompt length is <= the current context length
                # has started generating. Longer prompts still hold prompt
                # tokens at this position, which must not be overwritten.
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. The logits at position i predict
                    # token i + 1, so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:,
                                     prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit their stop condition.
            done = None
            if mpu.is_pipeline_last_stage():
                started_mask = started.byte()
                if stop_on_double_eol:
                    hit_double_eol = (new_sample == double_eol_id).byte() & \
                        started_mask
                    hit_two_eols = (new_sample == eol_id).byte() & \
                        (tokens[:, context_length - 1] == eol_id).byte() & \
                        started_mask
                    done_token = hit_double_eol | hit_two_eols
                elif stop_on_eol:
                    hit_double_eol = (new_sample == double_eol_id).byte() & \
                        started_mask
                    hit_eol = (new_sample == eol_id).byte() & started_mask
                    done_token = hit_double_eol | hit_eol
                else:
                    done_token = (new_sample == termination_id).byte() & \
                        started_mask

                # `~` on uint8 is bitwise, but done_token is 0/1, so the `&`
                # keeps only bit 0. Together they act as a logical "and not":
                # rows that stop now and had not stopped before.
                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # ===================================================
    # Update the length of based on max generated length.
    # ===================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs
mshoeybi's avatar
working  
mshoeybi committed
287

288
def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
    """Run beam search for a single prompt and return the best hypotheses.

    Arguments:
        model: model driven through ForwardStep (presumably no interleaving,
            as for the other generation functions — TODO confirm).
        tokens: prompt tokens padded to size [1, max-sequence-length].
        lengths: prompt length; must hold exactly one element (read with
            .item()).
        beam_size: number of beams kept alive at each step.
        stop_token: token id that closes a beam and moves it into the set of
            finished hypotheses.
        num_return_gen: maximum number of hypotheses to return.
        length_penalty: forwarded to BeamHypotheses (presumably used to
            normalize scores by hypothesis length; see beam_utils).
    Raises:
        ValueError: if the prompt leaves no room to generate within
            min(tokens.size(1), args.max_position_embeddings).
    Outputs (computed on the last stage, broadcast to the first stage):
        tokens: stacked hypotheses, best first. size: [n, tokens.size(1)]
            where n <= num_return_gen.
        scores: score of each returned hypothesis. size: [n]
    """
    args = get_args()

    batch_size = tokens.size(0)
    assert(batch_size == 1)
    prompt_length = lengths.item()
    final_sequence_length = tokens.size(1)
    final_sequence_length = min(final_sequence_length, args.max_position_embeddings)

    # If the context is too big, this happens
    if prompt_length >= final_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = ForwardStep(model, beam_size, final_sequence_length)

    beam_hyp = BeamHypotheses(beam_size, length_penalty)
    best_batches = None
    done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
    # Running cumulative log-prob of each live beam, shape [beam_size, 1] so
    # it broadcasts against per-beam vocab log-probs.
    scores = torch.zeros(beam_size,
                         dtype=torch.float32,
                         device=torch.cuda.current_device()).unsqueeze(1)
    scores_size_tensor, tokens_size_tensor = None, None

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        # One row per beam; all beams start from the same prompt.
        tokens = tokens.repeat(beam_size, 1)
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(prompt_length, final_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                vocab_size = logits.size(2)
                log_probs = F.log_softmax(logits, dim=2)
                new_scores = log_probs[:, -1, :] + scores

                # On the first step every beam is identical, so only beam 0
                # is ranked to avoid selecting the same continuation twice.
                if context_length == prompt_length:
                    sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
                else:
                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

                # Decode flat (beam, word) indices for the top 2*beam_size
                # candidates. Twice the beam size leaves enough non-stop
                # candidates even if some of them hit the stop token.
                best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
                best_words = indices[:2 * beam_size] % vocab_size
                best_scores = sorted_scores[: 2 * beam_size]

                next_beams = []
                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
                    zip(best_words, best_scores, best_beam_ids)
                ):
                    if token_id.item() == stop_token:
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        beam_hyp.add(
                            tokens[beam_id].clone(),
                            beam_score,
                            context_length + 1 - prompt_length
                        )
                    else:
                        # add next predicted token since it is not eos_token
                        next_beams.append((token_id, beam_score, beam_id))

                    if len(next_beams) == beam_size:
                        break

                if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                    done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())

                # Reorder rows to follow their parent beams, then append the
                # chosen token and carry the new cumulative scores forward.
                best_batches = tokens.new([item[2] for item in next_beams])
                tokens = tokens[best_batches,:]
                tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
                scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)

            done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
            if done:
                break

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64,
                                                   tokens)

            # set inference key values to make it consistent with best beam index
            best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
            forward_step.inference_params.swap_key_value_dict(best_batches)

            # Update the context length for the next token generation.
            prev_context_length = context_length

        if mpu.is_pipeline_last_stage():
            # if cannot find stop token, add open beams to hyps
            if not done:
                for beam_id in range(beam_size):
                    # scores[beam_id] has shape [1]. Squeeze it to 0-dim so it
                    # matches the scalar scores added inside the loop;
                    # otherwise torch.stack below fails on mixed shapes, or
                    # produces a 2-D scores tensor that disagrees with the
                    # 1-element size broadcast.
                    beam_hyp.add(tokens[beam_id].clone(), scores[beam_id].squeeze(), context_length + 1 - prompt_length)

            # rank based on scores
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
            num_return_gen = min(num_return_gen, len(sorted_hyps))
            scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
            tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
            scores = torch.stack(scores, dim=0)
            tokens = torch.stack(tokens, dim=0)
            scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
            tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())

        # Other stages learn the result shapes first, then receive the data.
        scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
        tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)

        scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
        tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)

    return tokens, scores

mshoeybi's avatar
mshoeybi committed
414
415
416
417
418
419
420
421
422
423
424
425
426
427

def _build_attention_mask_and_position_ids(tokens):
    """Return ``(attention_mask, position_ids)`` for a left-to-right pass.

    Thin wrapper around ``get_ltor_masks_and_position_ids``. The loss mask it
    also produces is discarded.
    """
    # Attention and positions are never reset at EOD boundaries and the loss
    # mask is thrown away, so the EOD token id is irrelevant and None is safe.
    attention_mask, _unused_loss_mask, position_ids = \
        get_ltor_masks_and_position_ids(
            data=tokens,
            eod_token=None,
            eod_mask_loss=False,
            reset_attention_mask=False,
            reset_position_ids=False,
        )
    return attention_mask, position_ids