# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import (
    get_args,
    mpu)



class InferenceParams:
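    """Inference parameters passed to the model's forward pass so the
    key/value context memory can be allocated once and then reused
    across generation steps."""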

    def __init__(self, max_batch_size, max_sequence_len):
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
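        # Allocate the key/value (context) memory only on the first forward
        # pass; _forward_step_helper flips this flag after that pass.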
        self.allocate_key_value_memory = True



class ForwardStep:
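    """Forward step with the required pipeline-parallel communication.
    A class is used so the inference parameters stay hidden from the
    caller; __call__ dispatches between the pipelined and the simple
    (non-pipelined) forward pass based on the batch x sequence size."""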

    def __init__(self, model, max_batch_size, max_sequence_len):

        # Make sure model is in eval mode.
        if isinstance(model, Iterable):
            for this_model in model:
                this_model.eval()
        else:
            model.eval()
        self.model = model

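        # Threshold on batch-size x sequence-length above which the forward
        # pass is split into micro batches and pipelined (see __call__).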
        self.constant = 512

        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)


    def __call__(self, tokens, position_ids, attention_mask):
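        """Run the forward pass. If batch-size x sequence-length is at or
        above the threshold, split the batch into micro batches and run
        the pipelined forward step; otherwise run a single forward step."""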
        if tokens.size(0) * tokens.size(1) >= self.constant:
            micro_batch_size = max(1, self.constant // tokens.size(1))
            return _with_pipelining_forward_step(self.model, tokens,
                                                 position_ids,
                                                 attention_mask,
                                                 self.inference_params,
                                                 micro_batch_size)
        else:
            return _no_pipelining_forward_step(self.model, tokens,
                                               position_ids,
                                               attention_mask,
                                               self.inference_params)
            


def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype



def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())



def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    if not mpu.is_pipeline_first_stage():
        torch.distributed.recv(recv_buffer,
                               src=mpu.get_pipeline_model_parallel_prev_rank())

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    if not mpu.is_pipeline_last_stage():
        torch.distributed.send(output_tensor,
                               mpu.get_pipeline_model_parallel_next_rank())

    # Make sure we do not allocate context memory anymore.
    if inference_params.allocate_key_value_memory:
        inference_params.allocate_key_value_memory = False


    return output_tensor



def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
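    """Forward step for a single batch without micro-batch pipelining."""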

    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

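    # Only the last pipeline stage has the logits; other stages return None.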
    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits


def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
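    """Forward step that splits the batch into micro batches of size
    micro_batch_size and pipelines them through the stages."""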

    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
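        # The last chunk may be smaller than micro_batch_size; its shape no
        # longer matches the preallocated recv buffer, so let the helper
        # allocate a fresh one.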
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # ... and reset the batch size offset for the next call.
    inference_params.batch_size_offset = 0

    return logits
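

# Example usage (a minimal sketch; assumes the caller has already built a
# Megatron model and the usual generation inputs -- tokens, position_ids,
# attention_mask -- on the current device):
#
#     forward_step = ForwardStep(model, max_batch_size=4, max_sequence_len=1024)
#     logits = forward_step(tokens, position_ids, attention_mask)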