# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import get_args
from megatron.core import mpu, InferenceParams
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)


class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_length):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_length)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold of pipelining.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold


    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
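                # E.g., with a threshold of 512 tokens and a sequence length
                # of 128, each micro batch holds max(1, 512 // 128) = 4
                # samples (illustrative values, not defaults).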
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)
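
# A minimal usage sketch (illustrative only, not part of the original module).
# `model`, `tokens`, `position_ids`, and `attention_mask` are assumed to come
# from the caller's generation loop:
#
#     forward_step = ForwardStep(model, max_batch_size=8,
#                                max_sequence_length=2048)
#     # Returns logits on the last pipeline stage, None elsewhere.
#     logits = forward_step(tokens, position_ids, attention_mask)
#
# The InferenceParams instance is created and advanced internally, so the
# caller never has to manage the KV-cache offsets directly.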



def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype



def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())



def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor



def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)
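    # Advancing the offset tells the attention layers where to append the
    # next step's keys/values in the preallocated inference cache.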

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits



def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
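    # E.g., batch_size=10 with micro_batch_size=4 gives divmod(10, 4) = (2, 2),
    # so three micro batches of sizes 4, 4, and 2 are processed below.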

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # and reset the batch size offset
    inference_params.batch_size_offset = 0

    return logits