# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
7
Testing ShardedDDP
8
9
"""

from contextlib import suppress
import copy
import tempfile

import numpy as np
import pytest
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import (
    GPT2,
    available_devices,
    check_same_model_params,
    check_same_models_across_ranks,
    skip_if_less_than_four_gpu,
    skip_if_no_cuda,
    skip_if_py38,
    skip_if_single_gpu,
)

def _get_mlp():
    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))


class _DoubleInput(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = _get_mlp()

    def forward(self, x, y):
        x1 = self.mlp(x)
        x2 = self.mlp(y)
        return torch.cat((x1, x2), dim=1)


def run_one_step(
    rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
    )

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_models_across_ranks(
        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
    )
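    # Note: sync_models_at_startup defaults to True, so the wrapper synchronizes the model replicas
    # (parameters, and buffers when broadcast_buffers=True) across ranks at construction time,
    # which is what the check above relies on.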

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
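            # no_sync() disables the gradient reduction for this backward pass, so gradients simply
            # accumulate locally; suppress() keeps the non-accumulating path unchanged.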
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)
        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_models_across_ranks(
            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
        )

    dist.destroy_process_group()


def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_one_step,
        args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
    world_size = 2
    run_test(
        dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
    )


@skip_if_py38
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
    world_size = 2
    run_test(
        dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
    )


def run_ddp_parity(
    rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHES = 5
    BATCH_SIZE = 8

    def check_parity(amp: bool, manual_reduction: bool):

        # The API should be exactly the same between the sharded and non-sharded variants, hence a generic closure
        def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            if not _manual_reduction:
                step()
            else:
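                # Manual reduction: run the last backward under no_sync() as well, then trigger the
                # gradient reduction explicitly through ShardedDataParallel.reduce().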
                with model.no_sync():
                    step()

                model.reduce()

        # Any model works. Add one different buffer per rank
        model = _get_mlp()
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with a non-trainable parameter, so that we check that the buckets
        # are properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model,
            sharded_optimizer=sharded_optimizer,
            broadcast_buffers=True,
            reduce_buffer_size=reduce_buffer_size,
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None
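        # ShardedGradScaler is the shard-aware counterpart of torch.cuda.amp.GradScaler: each rank only
        # sees the gradients of its own optimizer shard, so the scaler needs to share its inf/nan
        # bookkeeping across ranks before deciding whether to skip the step.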

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHES):
            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

            def closure_sharded(input_tensor=input_tensor):
                return closure(
                    sharded_ddp_model,
                    sharded_ddp_scaler,
                    input_tensor,
                    grad_accumulation,
                    _manual_reduction=manual_reduction,
                )

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()
                ).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
                check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    amp_tests = [False]
    if hasattr(torch.cuda.amp, "autocast"):
        amp_tests.append(True)

    manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
    for manual_reduction in manual_reductions:
        for amp in amp_tests:
            print(
                f"Checking configuration: accumulate {grad_accumulation}"
                + f" - change train graph {change_train_graph}"
                + f" - amp {amp}"
                + f" - manual reduction {manual_reduction}"
                + f" - buffers {reduce_buffer_size}",
                flush=True,
            )
            check_parity(
                amp=amp, manual_reduction=manual_reduction,
            )

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("change_train_graph", [True, False])
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph):
    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL
    mp.spawn(
        run_ddp_parity,
        args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size, grad_accumulation, change_train_graph),
        nprocs=world_size,
        join=True,
    )


def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    BATCHES = 20

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, **optim_settings)
    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, **optim_settings)
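    # ShardedDataParallel accepts a list of OSS optimizers; each one shards the state for its own
    # subset of the parameters (here, the two halves of the MLP).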

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for i in range(BATCHES):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model, ddp_model, f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_ddp_parity_two_optim(reduce_buffer_size):
    world_size = 2
    backend = dist.Backend.NCCL
    mp.spawn(
        run_ddp_parity_two_optim,
        args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
@pytest.mark.parametrize("device", available_devices)
def test_inputs(reduce_buffer_size, backend, device):
    # Check that the ShardedDDP wrapper accepts several tensor inputs
    world_size = 2
    if backend == "nccl" and device == "cpu":
        pytest.skip("Incompatible combination, or cuda not available")
        return

    mp.spawn(
        run_test_two_inputs,
        args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is_multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()


def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()


def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that a device change on the wrapped module is caught and rejected
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # deliberately left on CPU; the device change is tested after wrapping
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
    )
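    # sync_models_at_startup is disabled here, presumably because the weights still live on CPU and
    # could not be broadcast over NCCL before the (rejected) device change.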
    try:
        ddp_model.to(device)
        assert False, "Changing devices should be caught and not supported"
    except AssertionError:
        pass

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_device_change(reduce_buffer_size):
    # Check that changing the device of a wrapped module is properly rejected
    world_size = 2
    backend = "nccl"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_device_change,
        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    group = dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_training_change(reduce_buffer_size):
    world_size = 2
    backend = "nccl"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_training_change,
        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)  # in pytorch 1.5 syncBN switches to the default device/cpu

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
    )


def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
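        # Each OSS instance only updates its own slice of the parameters, so both step() calls are
        # needed per iteration; each call re-runs the closure (a fresh forward and backward).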
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()


def test_two_optimizers():
    # Check that the ShardedDDP wrapper works with several sharded optimizers at once
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)
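    # A GPT2-sized model stresses the gradient reduction path with many parameters of varying
    # shapes; the loop below checks that the bucketing logic does not overflow.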

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_gpt2():
    # Check that ShardedDDP handles a larger model (GPT2) end to end
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only the even ranks do actual work, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)
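    # torch.distributed.new_group() has to be called by every rank, even those not included in the
    # new group; only the ranks in sub_group_ranks actually train below.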

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # A model which does not fit in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)


@skip_if_less_than_four_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_multiple_groups(reduce_buffer_size, backend):
    world_size = 4
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(
        run_test_multiple_groups,
        args=(world_size, temp_file_name, backend, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )