# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import (
    get_args,
    mpu)
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)



class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
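        # Lazily populated per-layer cache: the attention layers are expected
        # to allocate their key/value tensors on the first call and store
        # them here as layer number -> (key memory, value memory).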
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        """Re-order the cached key/value memory along the batch dimension."""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number, (inference_key_memory, inference_value_memory) in \
                self.key_value_memory_dict.items():
            # Make sure the batch size has not changed.
            assert len(batch_idx) == inference_key_memory.shape[1]
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory, new_inference_value_memory)
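

# A minimal usage sketch (illustrative only; ForwardStep below is the real
# caller in this file):
#
#     params = InferenceParams(max_batch_size=4, max_sequence_len=2048)
#     logits = model(tokens, position_ids, attention_mask,
#                    inference_params=params)
#     params.sequence_len_offset += tokens.size(1)
#
# The attention layers are expected to allocate and reuse their key/value
# cache through key_value_memory_dict, indexed with the offsets above.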

class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold (batch size x sequence length) above which
        # the batch is split into micro-batches and pipelined.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold


    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)
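
# Illustrative usage from a generation loop (sizes here are made up):
#
#     forward_step = ForwardStep(model, max_batch_size=8, max_sequence_len=1024)
#     logits = forward_step(tokens, position_ids, attention_mask)
#
# Only the last pipeline stage returns actual logits; other stages get None.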



def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype



def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())



def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
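    # On the first pipeline stage recv_buffer stays None: the receive below is
    # expected to be a no-op there, and passing None to set_input_tensor leaves
    # the model to start from its own embedding output.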

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor



def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor
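    # Only the last pipeline stage produces logits; earlier stages return
    # None and the caller is expected to handle that.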

    return logits



def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
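    # For example, batch_size=10 with micro_batch_size=4 gives divmod -> (2, 2),
    # i.e. three micro-batches of sizes 4, 4 and 2.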

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
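    # The same buffer is reused for every full-size micro-batch; a smaller
    # trailing micro-batch falls back to an on-the-fly allocation below.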

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size
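        # The attention layers are expected to use this offset to write the
        # current micro-batch into the right slice of their preallocated
        # key/value cache.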

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # and reset the batch size offset
    inference_params.batch_size_offset = 0

    return logits