# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import (
    get_args,
    mpu)
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)



class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.allocate_key_value_memory = True



class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold above which pipelining (micro-batching) is used.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold


    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
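                # Worked example with illustrative numbers (not from the
                # source): a threshold of 512 and a sequence length of 128
                # give micro batches of max(1, 512 // 128) = 4 samples each.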
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)
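
# Example usage (illustrative sketch, not part of the original file): the
# caller is assumed to supply `model`, `tokens`, `position_ids`, and
# `attention_mask`, and to choose `max_sequence_len` large enough to cover
# the prompt plus the tokens to be generated.
#
#   forward_step = ForwardStep(model, max_batch_size, max_sequence_len)
#   logits = forward_step(tokens, position_ids, attention_mask)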



def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype



def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
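    # Shape example with illustrative numbers (not from the source):
    # sequence_length=128, batch_size=4, hidden_size=1024 gives a
    # receive buffer of shape (128, 4, 1024), i.e. [s, b, h].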
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())



def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    # Make sure we do not allocate context memory anymore.
    if inference_params.allocate_key_value_memory:
        inference_params.allocate_key_value_memory = False


    return output_tensor



def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits



def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
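    # Worked example with illustrative numbers (not from the source):
    # batch_size=10 and micro_batch_size=4 give divmod(10, 4) = (2, 2),
    # so three micro batches of sizes 4, 4, and 2 are run.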

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, adjust the
    # sequence length offset ...
    inference_params.sequence_len_offset += sequence_length
    # ... and reset the batch size offset.
    inference_params.batch_size_offset = 0

    return logits