# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
from contextlib import contextmanager
17
import torch
18
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
19
20

from megatron import get_args
21
from megatron import get_num_microbatches
22
23
from megatron import get_timers
from megatron import mpu
24
from megatron import p2p_communication
25
26
27
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
28

def get_forward_backward_func():
    """Pick the forward/backward schedule matching the parallel configuration.

    Returns forward_backward_no_pipelining when the pipeline-model-parallel
    world size is not greater than 1. Otherwise returns the interleaved
    schedule when args.virtual_pipeline_model_parallel_size is set, and the
    non-interleaved schedule when it is None.
    """
    args = get_args()

    # No pipeline parallelism: a single stage runs everything locally.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining

    # Pipeline parallelism without virtual stages.
    if args.virtual_pipeline_model_parallel_size is None:
        return forward_backward_pipelining_without_interleaving

    return forward_backward_pipelining_with_interleaving


41
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    The passed-in input_tensor is handed to the unwrapped model through
    set_input_tensor() (the pipelined schedules pass None on the first
    stage), then forward_step_func(data_iterator, model) is called and must
    return a pair (output_tensor, loss_func).

    On the last pipeline stage, loss_func(output_tensor) is expected to
    return a pair (loss, loss_reduced): loss_reduced is appended to
    losses_reduced and the returned tensor is the loss divided by the number
    of microbatches.

    Returns output tensor."""
    timers = get_timers()

    timers('forward-compute').start()
    # set_input_tensor lives on the innermost module, so strip the wrappers.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    # Note: the (possibly wrapped) model is what gets called, not the
    # unwrapped one.
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        # Scale by the microbatch count before the loss is used for backward.
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor. When output_tensor_grad is None
    the output tensor is first passed through optimizer.scale_loss().

    Returns gradient of loss with respect to input tensor (None if first
    stage, i.e. when input_tensor is None)."""
    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor (it is not a leaf, so its .grad
    # would otherwise not be kept after backward).
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


97
98
99
100
101
102
103
104
@contextmanager
def dummy_handler():
    """No-op context manager: yields nothing and does no cleanup on exit."""
    yield


105
106
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    model must be a list holding exactly one module. The timers argument is
    accepted for signature parity with the pipelined schedules and is not
    used here. When forward_only is true the backward passes are skipped.

    Returns the list of reduced losses collected by forward_step()."""
    assert len(model) == 1
    model = model[0]

    # torch DDP exposes no_sync(); use it for all but the last microbatch.
    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    # There is no neighbouring stage, so nothing is received or sent.
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for _ in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    model and data_iterator are indexed by model chunk id, so both must hold
    one entry per model chunk. When forward_only is true only the forward
    passes are run and no tensors are kept for a backward pass.

    Returns the list of reduced losses; forward_step() only appends to it on
    the last pipeline stage, so it is empty otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            # Backward passes walk the model chunks in reverse order.
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            # No input is pending for this chunk: use a None placeholder.
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # If forward-only, no backward pass will ever pop these entries, so
        # drop them now instead of keeping every microbatch's tensors alive
        # until the function returns. Both lists shrink by one, which keeps
        # the length comparison above valid.
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            # No gradient is pending for this chunk: use a None placeholder.
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        # Tensors are consumed in the order they were produced (FIFO).
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            # Last warmup step before 1F1B: also receive the first gradient.
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


360
361
362
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    model must be a list holding exactly one module. When forward_only is
    true only forward passes are run and no tensors are kept for a backward
    pass.

    Returns the list of reduced losses; forward_step() only appends to it on
    the last pipeline stage, so it is empty otherwise."""
    # NOTE: the timers argument is replaced by the global timers here.
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Tensors saved for the backward pass; left empty when forward_only.
    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        # Only keep the tensors when a backward pass will consume them;
        # otherwise every microbatch's tensors would stay referenced until
        # the function returns.
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)

            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

            # Add input_tensor and output_tensor to end of list, then pop from
            # the start of the list for backward pass.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced