# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing ShardedDDP
"""

from contextlib import suppress
import copy
import tempfile
from typing import List

import numpy as np
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import (
    GPT2,
    check_same_model_params,
    skip_if_less_four_gpu,
    skip_if_no_cuda,
    skip_if_py38,
    skip_if_single_gpu,
)


def run_one_step(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        next(model.parameters()).requires_grad = False  # Test non-trainable parameters

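        # OSS shards the optimizer state across the ranks; ShardedDataParallel reduces each gradient
        # to the rank which owns the corresponding optimizer shard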
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)

        def check_same_model_params(same_params: bool):
            # Check that all the params are the same on all ranks
            # This should be true with and without broadcast_buffers, we don't have any real buffer here
            receptacle: List[torch.Tensor] = []

            if dist.get_backend() != "nccl":
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        # Check the params
                        receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
                        dist.gather(p, receptacle, dst=0)
                        if rank == 0:
                            for sync_p in receptacle[1:]:
                                if same_params:
                                    assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Gradients should not have been synced"

                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
                if broadcast_buffers:
                    for b in ddp_model.buffers():
                        receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
                        dist.gather(b, receptacle, dst=0)
                        if rank == 0:
                            for sync_b in receptacle[1:]:
                                if same_params:
                                    assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Gradients should not have been synced"

                        assert b.cpu().item() == 0.0

        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
        check_same_model_params(same_params=True)

        # Optim loop
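        # With grad_accumulation the forward/backward runs under no_sync(), so the gradient
        # reduction across ranks is skipped for that step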
        def closure():
            optimizer.zero_grad()

            with ddp_model.no_sync() if grad_accumulation else suppress():
                input_tensor = torch.rand((64, 2)).to(device)
                loss = ddp_model(input_tensor).abs().sum()
                loss.backward()
            return loss

        # The models should stay the same in between the ranks
        for i in range(5):
            _ = optimizer.step(closure=closure)
            # when running on cpu/gloo the "nodes" are not really different
            same_params = device == torch.device("cpu") or grad_accumulation
            check_same_model_params(same_params=same_params)

    check(broadcast_buffers=False)
    check(broadcast_buffers=True)
    check(broadcast_buffers=False, grad_accumulation=True)
    check(broadcast_buffers=True, grad_accumulation=True)
    dist.destroy_process_group()


def run_test(backend, device, world_size=2):
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


@skip_if_no_cuda
@skip_if_single_gpu
def test_step_gpu():
    run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))


@skip_if_py38
def test_step_cpu():
    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))


def run_ddp_parity(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5
    INPUTS = 2
    BATCH_SIZE = 32

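    # Train the same model with ShardedDDP + OSS and with vanilla DDP + SGD, and check after
    # every step that the parameters and buffers stay identical on this rank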
    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool, manual_reduction: bool):

        # The API should be the exact same in between the sharded and non-sharded variants, generic closure
        def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
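            # Accumulate over several forward/backward passes when requested; all but the last
            # one run under no_sync(), so no gradient reduction happens for those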
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

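            # With manual reduction the last backward also runs under no_sync(), and the gradient
            # reduction across ranks is triggered explicitly through model.reduce()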
            if not _manual_reduction:
                step()
            else:
                with model.no_sync():
                    step()

                model.reduce()

        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with non-trainable, so that we check for the buckets to be
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

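        # AMP: the stock torch GradScaler drives the DDP model, fairscale's ShardedGradScaler the sharded one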
        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

            def closure_sharded(input_tensor=input_tensor):
                return closure(
                    sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate, _manual_reduction=manual_reduction
                )

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()
                ).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
                check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    # Test all combinations: AMP, Accumulate, Change train graph
    amp_tests = [False]
    if hasattr(torch.cuda.amp, "autocast"):
        amp_tests.append(True)

    for accumulate in [False, True]:
        for change_train_graph in [False, True]:
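            # Manual reduction is only exercised in the plain setting (no accumulation, no train graph change)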
            manual_reductions = [False, True] if not accumulate and not change_train_graph else [False]
            for manual_reduction in manual_reductions:
                for amp in amp_tests:
                    print(
                        f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp} - manual reduction {manual_reduction}"
                    )
                    check_parity(
                        amp=amp,
                        accumulate=accumulate,
                        change_train_graph=change_train_graph,
                        manual_reduction=manual_reduction,
                    )

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_parity():
    temp_file_name = tempfile.mkstemp()[1]
    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL
    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)


def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
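    # Shard each half of the parameters with its own OSS optimizer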

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )

    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    check_same_model_params()  # The models should stay the same in between the ranks

    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        check_same_model_params()

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_parity_two_optim():
    temp_file_name = tempfile.mkstemp()[1]
    world_size = 2
    backend = dist.Backend.NCCL
    mp.spawn(run_ddp_parity_two_optim, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)


def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


def test_inputs():
    # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is multi_device_module
    # - device_type

    url = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()


def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes

    url = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()


def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    # Check that the wrapped module can change devices

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
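    # Move the wrapper (and thus the wrapped module) to the target device after construction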
    ddp_model.to(device)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # this fails if the wrapped module was not moved to the new device
    _ = outputs.norm().backward()

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_device_change():
    # Check that the wrapped module can be moved to another device after the ShardedDDP wrap
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_training_change(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    group = dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # sanity check: forward/backward go through before switching to eval()
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()


def test_training_change():
    world_size = 8
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_training_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
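    # Swap the BatchNorm layer for SyncBatchNorm so that its statistics are synced across ranks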
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)  # in pytorch 1.5 syncBN switches to the default device/cpu

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
    )


def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
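    # Split the parameters across two sharded optimizers; ShardedDataParallel accepts a list of them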
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()


def test_two_optimizers():
    # Check that the ShardedDDP wrapper supports several sharded optimizers at once
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
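    # GPT2 comes from fairscale.utils.testing, a model big enough to exercise the gradient reduce buckets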
    model = GPT2(
        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_gpt2():
    # Check that the ShardedDDP wrapper handles a bigger model (GPT2) without bucketing overflows
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
    # Only work with the even ranks, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            for pg in optimizer.param_groups:
                for p in pg["params"]:
                    receptacle = [p.clone() for _ in sub_group_ranks]
                    dist.all_gather(receptacle, p, group=process_group)
                    if rank == 0:
                        for sync_p in receptacle[1:]:
                            assert torch.all(
                                torch.eq(receptacle[0], sync_p)
                            ), "Models differ in between ranks {} - {}".format(
                                torch.norm(receptacle[0]), torch.norm(sync_p)
                            )

    if rank in sub_group_ranks:
        # Model not-fitting in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group)
        model = ShardedDataParallel(model, optimizer, process_group=process_group)
        check(optimizer, model)

    dist.destroy_process_group(process_group)


@skip_if_less_four_gpu
def test_multiple_groups():
    world_size = 4
    temp_file_name = tempfile.mkstemp()[1]

    for backend in ["gloo", "nccl"]:
        print("Testing backend ", backend)
        mp.spawn(
            run_test_multiple_groups, args=(world_size, temp_file_name, backend), nprocs=world_size, join=True,
        )
