# schedules.py 19 KB
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
from contextlib import contextmanager
17
import torch
18
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
19
20

from megatron import get_args
21
from megatron import get_num_microbatches
22
23
from megatron import get_timers
from megatron import mpu
24
from megatron import p2p_communication
25
26
27
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
28

# Jared Casper's avatar
# Jared Casper committed
29
30
31
32
33
def get_forward_backward_func():
    """Select the forward/backward schedule for the current parallel layout.

    Returns:
        ``forward_backward_pipelining_with_interleaving`` when the pipeline
        model-parallel world size is > 1 and
        ``args.virtual_pipeline_model_parallel_size`` is set;
        ``forward_backward_pipelining_without_interleaving`` when the
        pipeline world size is > 1 without virtual stages;
        ``forward_backward_no_pipelining`` otherwise.

    Raises:
        AssertionError: if the interleaved schedule is selected and the
            number of microbatches is not divisible by
            ``args.pipeline_model_parallel_size``.
    """
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            # The interleaved schedule maps microbatch ids onto model chunks
            # in runs of pipeline-parallel size, so the count must divide
            # evenly.
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func


44
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Args:
        forward_step_func: callable ``(data_iterator, model)`` returning
            ``(output_tensor, loss_func)``.
        data_iterator: iterator forwarded to ``forward_step_func``.
        model: this stage's model, possibly wrapped in torchDDP, LocalDDP
            and/or Float16Module.
        input_tensor: activation from the previous stage, or ``None``; it is
            installed on the unwrapped model via ``set_input_tensor``.
        losses_reduced: list; on the last pipeline stage the reduced loss
            produced by ``loss_func`` is appended to it.

    Returns:
        The stage's output tensor. On the last stage this is the loss
        divided by the number of microbatches.
    """
    timers = get_timers()

    timers('forward-compute').start()
    # Strip the DDP / fp16 wrappers so set_input_tensor reaches the
    # underlying model.
    unwrapped_model = unwrap_model(
        model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
        output_tensor = loss_func(output_tensor)
        loss, loss_reduced = output_tensor
        # Divide by the microbatch count so that the gradients summed over
        # all microbatches equal those of the mean loss.
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Args:
        optimizer: object providing ``scale_loss``; used only when
            ``output_tensor_grad`` is None, i.e. ``output_tensor`` is the loss.
        input_tensor: this stage's input activation, or ``None``.
        output_tensor: tensor to backpropagate from.
        output_tensor_grad: gradient w.r.t. ``output_tensor`` received from
            the next stage, or ``None`` on the last stage.

    Returns:
        Gradient of the loss with respect to ``input_tensor``, or ``None``
        when ``input_tensor`` is None (e.g. on the first stage).
    """
    timers = get_timers()
    timers('backward-compute').start()

    # Keep .grad populated on input_tensor so it can be read (and sent
    # upstream by the caller) after backward.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass. Only the loss itself (no incoming grad) is scaled.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


100
101
102
103
104
105
106
107
@contextmanager
def dummy_handler():
    """No-op context manager.

    Stands in for ``model.no_sync`` when the model is not a torch DDP
    instance: entering yields ``None`` and exiting does nothing, letting any
    exception raised in the ``with`` body propagate unchanged.
    """
    yield


108
109
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Args:
        forward_step_func: passed through to ``forward_step``.
        data_iterator: passed through to ``forward_step``.
        model: single-element list holding the model.
        optimizer: passed through to ``backward_step``.
        timers: unused here; accepted for signature parity with the
            pipelined schedules.
        forward_only: if True, skip the backward passes.

    Returns:
        The list of reduced losses appended by ``forward_step``.
    """
    assert len(model) == 1
    model = model[0]

    # With torch DDP, suppress gradient synchronization for all but the
    # last microbatch; other model types get a no-op context.
    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for _ in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                  optimizer, timers, forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Args:
        forward_step_func: passed through to ``forward_step``.
        data_iterator: list indexed by model chunk id; each entry is handed
            to ``forward_step`` for that chunk.
        model: list of model chunks, indexed by virtual pipeline rank.
        optimizer: passed through to ``backward_step``.
        timers: passed to the ``p2p_communication`` helpers.
        forward_only: if True, run forward passes only and keep no
            activations for a backward pass.

    Returns:
        The list of reduced losses appended by ``forward_step``. Entries are
        only added where ``mpu.is_pipeline_last_stage()`` holds, so the list
        is empty on other stages.
    """
    # Per-chunk FIFO queues of activations / gradients awaiting use.
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number.

        Consecutive runs of ``pipeline_parallel_size`` microbatch ids map to
        the same chunk; backward passes visit chunks in reverse order."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # forward step
        # The first stage has no upstream activation; pad its input queue
        # with None so input/output queues stay the same length.
        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        # The last stage backpropagates from the loss, so it has no incoming
        # gradient; backward_step treats None as "scale and use the loss".
        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers=timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            # Last warmup step: also pre-fetch the first backward gradient
            # needed by the steady-state loop.
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, timers=timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            # No steady state ran, so the first backward gradient has not
            # been received yet.
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers=timers))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, timers=timers))

    return losses_reduced


369
370
371
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Args:
        forward_step_func: passed through to ``forward_step``.
        data_iterator: passed through to ``forward_step``.
        model: single-element list holding this stage's model.
        optimizer: passed through to ``backward_step``.
        timers: passed to the ``p2p_communication`` helpers; when None, the
            global timers from ``get_timers()`` are used instead.
        forward_only: if True, run forward passes only and keep no
            activations for a backward pass.

    Returns:
        The list of reduced losses appended by ``forward_step``. Entries are
        only added on the last pipeline stage, so the list is empty on
        other stages.
    """
    # Honor an explicitly passed timers object; fall back to the global
    # timers when none is given so pipeline communication is still timed.
    if timers is None:
        timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches: stage r runs
    # (world_size - r - 1) forward passes before its first backward.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for _ in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers=timers)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers=timers)

            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers=timers)

        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers=timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers=timers)

    # Run cooldown backward passes.
    if not forward_only:
        for _ in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers=timers)

    return losses_reduced