# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import (
    get_args,
    mpu)
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)



class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        """Swap the cached keys/values between batches."""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = \
                self.key_value_memory_dict[layer_number]
            # Make sure the batch size is unchanged.
            assert len(batch_idx) == inference_key_memory.shape[1]
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory, new_inference_value_memory)
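
# A minimal usage sketch (comments only; not part of the module). It assumes
# generation code that runs forward passes and fills key_value_memory_dict
# with per-layer (key, value) tensors whose dimension 1 is the batch:
#
#     params = InferenceParams(max_batch_size=4, max_sequence_len=2048)
#     ...  # forward steps populate params.key_value_memory_dict
#     params.swap_key_value_dict([2, 0, 1, 3])  # reorder the cached batch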

class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold of pipelining.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold


    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)

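# A minimal usage sketch (comments only; `model`, `tokens`, `position_ids`,
# `attention_mask`, and `max_generation_length` are assumed to be provided by
# the surrounding generation code, not by this module):
#
#     forward_step = ForwardStep(model,
#                                max_batch_size=tokens.size(0),
#                                max_sequence_len=max_generation_length)
#     logits = forward_step(tokens, position_ids, attention_mask)
#     # On the last pipeline stage `logits` covers the whole batch (the
#     # pipelined path below preallocates a
#     # [batch, sequence, padded_vocab_size] buffer); other stages get None.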


def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype



def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())



def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
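    # On the first pipeline stage _allocate_recv_buffer returns None; the
    # receive below is then expected to be a no-op and that stage computes
    # its input from the token embeddings instead.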

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor



def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits



def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
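    # For example, batch_size=10 with micro_batch_size=4 gives
    # num_micro_batches=3 (micro-batches of sizes 4, 4, and 2).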

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
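        # The last micro-batch may be smaller than micro_batch_size; in that
        # case drop the preallocated buffer so the helper allocates one with
        # a matching batch dimension.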
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # and reset the batch size offset
    inference_params.batch_size_offset = 0

    return logits