# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import get_args
from megatron.core import mpu
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)



class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        """Swap the cached key/value memory between batches by re-indexing
        the batch dimension with `batch_idx`."""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = \
                self.key_value_memory_dict[layer_number]
            # Make sure the batch size is the same.
            assert len(batch_idx) == inference_key_memory.shape[1]
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory, new_inference_value_memory)
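
# A minimal usage sketch of InferenceParams (illustrative only; it assumes the
# model fills `key_value_memory_dict` with key/value tensors laid out as
# [sequence, batch, ...], which is what swap_key_value_dict indexes):
#
#   inference_params = InferenceParams(max_batch_size=4, max_sequence_len=2048)
#   # ... run forward steps so the model populates key_value_memory_dict ...
#   inference_params.swap_key_value_dict(torch.tensor([3, 2, 1, 0]))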

class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold (batch size x sequence length) at or above which we
        # switch to the pipelined forward step.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold


    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
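                # Pick the largest micro-batch size whose batch x seqlen
                # product stays within the threshold (at least 1).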
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)
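
# A minimal usage sketch (illustrative only; `model`, `tokens`, `position_ids`
# and `attention_mask` are assumed to be prepared by the caller, e.g. by the
# text generation utilities):
#
#   forward_step = ForwardStep(model, max_batch_size=8, max_sequence_len=1024)
#   logits = forward_step(tokens, position_ids, attention_mask)
#   # logits is only non-None on the last pipeline stage.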



def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype



def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
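    # e.g. (assumed sizes) sequence_length=1024, batch_size=8 and
    # hidden_size=4096 give a buffer of shape (1024, 8, 4096).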
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())



def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor



def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

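    # Only the last pipeline stage holds the model output (logits); all
    # other stages return None.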
    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits



def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
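    # e.g. batch_size=10, micro_batch_size=4 gives num_micro_batches=3
    # (micro-batches of size 4, 4, and 2).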

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass. If the (last) micro-batch is smaller than
        # the preallocated recv buffer, let the helper allocate a matching one.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # and reset the batch size offset
    inference_params.batch_size_offset = 0

    return logits