"tools/setup_helpers/extension.py" did not exist on "092a786984d4e69812e4de538c7a40921d5d1281"
generation.py 9.12 KB
Newer Older
mshoeybi's avatar
mshoeybi committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generation utilities."""

import torch
import torch.nn.functional as F

mshoeybi's avatar
working  
mshoeybi committed
21
from megatron import get_args, get_tokenizer, mpu
mshoeybi's avatar
mshoeybi committed
22
23
24
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
mshoeybi's avatar
working  
mshoeybi committed
25
26
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
mshoeybi's avatar
mshoeybi committed
27
from .forward_step import ForwardStep
mshoeybi's avatar
mshoeybi committed
28
29
30
from .sampling import sample


mshoeybi's avatar
working  
mshoeybi committed
31
32
33
def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True):
    """Main token generation function.

    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, stop early once every
            sequence has produced the termination token (`args.eos_id` when
            that attribute exists, otherwise `tokenizer.eod`).

    Note: every rank that runs this function reads the shape of `tokens`,
          the minimum of `lengths`, `return_output_log_probs` and
          `use_eod_token_for_early_termination`. These values determine the
          number of loop iterations and which broadcasts are issued, so they
          should agree across ranks.

    Outputs: Note that the size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]

    Raises:
        ValueError: if the shortest prompt already reaches the maximum
            sequence length (after clipping to `args.max_position_embeddings`),
            leaving no position in which to generate a token.
    """

    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # Without this guard the generation loop below would not execute and
    # `context_length` would be unbound when the outputs are trimmed.
    if min_prompt_length >= max_sequence_length:
        raise ValueError(
            'shortest prompt length ({}) leaves no room for generation within '
            'the maximum sequence length ({})'.format(
                min_prompt_length, max_sequence_length))

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequences including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        # Sequences that never hit the termination id keep the full length.
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network. Only
            # the positions not yet fed in are passed (the first iteration
            # feeds the shared prompt prefix).
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.
                assert logits is not None

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # If a prompt length is smaller than or equal to the current
                # context length, that sequence has started generating. For
                # the others, position `context_length` still holds a prompt
                # token and must not be overwritten.
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that next input token is
                    # the token which we selected in the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[
                            :,
                            (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:,
                                     prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                done_token = (new_sample == termination_id).byte() & \
                    started.byte()
                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            # Every stage receives the flag so all ranks leave the loop
            # on the same iteration.
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # =============================================
    # Trim the outputs to the max generated length.
    # =============================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs
mshoeybi's avatar
working  
mshoeybi committed
199

mshoeybi's avatar
mshoeybi committed
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214


def _build_attention_mask_and_position_ids(tokens):
    """Return the ``(attention_mask, position_ids)`` pair for ``tokens``.

    Delegates to ``get_ltor_masks_and_position_ids`` and discards the loss
    mask it also produces.
    """
    # Position ids and the attention mask are not reset at EOD boundaries, and
    # the loss mask is thrown away. Per the original author, the EOD token id
    # is therefore never consulted, so None is passed for it.
    mask, _unused_loss_mask, positions = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False)
    return mask, positions