# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing ShardedDDP
"""

from contextlib import suppress

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential

from fair_dev.testing.testing import (
    GPT2,
    SGDWithPausingCompute,
    available_devices,
    check_same_models_across_ranks,
    skip_if_less_than_four_gpu,
    skip_if_no_cuda,
    skip_if_single_gpu,
    temp_files_ctx,
)

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
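
# The wrapping pattern exercised throughout these tests is roughly the following
# (a sketch -- see the individual tests below for the exact arguments used):
#
#   optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
#   ddp_model = ShardedDataParallel(model, optimizer)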


def _get_mlp(tripwire: bool = False):
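    """Return a small MLP, or a corner-case model with an extra non-trainable integer parameter if ``tripwire`` is set."""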
    if not tripwire:
        return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

    class Tripwire(torch.nn.Module):
        """A model made to expose possible corner cases"""

        def __init__(self) -> None:
            super().__init__()
            self.model = Linear(2, 3, bias=False)

            # A non-trainable parameter with a mismatched (integer) dtype; mixing types like this can trip the reduce buckets
            self.register_parameter("tripwire", torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False))

        def forward(self, x):
            return self.model(x)

    return Tripwire()


class _DoubleInput(torch.nn.Module):
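    """A module taking two inputs, used to check that ShardedDDP forwards several arguments to the wrapped model."""
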
    def __init__(self):
        super().__init__()
        self.mlp = _get_mlp()

    def forward(self, x, y):
        x1 = self.mlp(x)
        x2 = self.mlp(y)
        return torch.cat((x1, x2), dim=1)


def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
    optimizer_type,
    reduce_fp16=False,
):
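    """Per-rank worker: wrap an MLP with OSS + ShardedDataParallel, run a few optimizer steps, and check that
    parameters and buffers stay in sync across ranks."""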
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones(1) * rank)
    model.to(device)

    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer_settings = {"lr": 1e-3, "momentum": 0.99}
    if optimizer_type == SGDWithPausingCompute:
        optimizer_settings["rank"] = rank

    optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
    ddp_model = ShardedDataParallel(
        model,
        optimizer,
        broadcast_buffers=broadcast_buffers,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=reduce_fp16,
    )
    # The model should be synchronized across the ranks at ShardedDataParallel construction time; check that it is
    check_same_models_across_ranks(
        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
    )
    # Optim loop
    def closure():
        ddp_model.zero_grad(set_to_none=True)

        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            # If grad_accumulation, we can check after the forward that the models are different
            # (not synced)
            if grad_accumulation:
                check_same_models_across_ranks(
                    ddp_model, dist.group.WORLD, params_should_be_equal=False, check_broadcast_buffers=True
                )

            loss.backward()
        return loss

    # The models should stay the same across the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)
        # Make sure all the streams are synced
        if device.type == torch.device("cuda").type:
            torch.cuda.synchronize(device=device)

        # When running on cpu/gloo the "nodes" are not really different; with grad accumulation
        # the reduction was skipped on purpose, so the ranks are expected to diverge
        same_params = device == torch.device("cpu") or not grad_accumulation
        check_same_models_across_ranks(
            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
        )
    dist.destroy_process_group()

def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
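    """Spawn ``world_size`` processes each running ``run_one_step`` with the given configuration."""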
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_one_step,
            args=(
                world_size,
                backend,
                device,
                temp_files[0],
                broadcast_buffers,
                grad_accumulation,
                reduce_buffer_size,
                optimizer_type,
            ),
            nprocs=world_size,
            join=True,
        )


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2**20])
@pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
@pytest.mark.parametrize("reduce_fp16", [False, True])
@pytest.mark.parametrize(
    "setup",
    [
        [dist.Backend.NCCL, torch.device("cuda")],
        [dist.Backend.GLOO, torch.device("cpu")],
        [dist.Backend.GLOO, torch.device("cuda")],
    ],
)
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
    world_size = 2
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_one_step,
            args=(
                world_size,
                setup[0],
                setup[1],
                temp_files[0],
                broadcast_buffers,
                grad_accumulation,
                reduce_buffer_size,
                optimizer_type,
                reduce_fp16,
            ),
            nprocs=world_size,
            join=True,
        )


def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
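    """Per-rank worker: check that ShardedDDP forwards several positional tensor inputs to the wrapped module."""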
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        ddp_model.zero_grad(set_to_none=True)
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for _ in range(5):
        _ = optimizer.step(closure=closure)
    dist.destroy_process_group()


@pytest.mark.parametrize("reduce_buffer_size", [0, 2**20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
@pytest.mark.parametrize("device", available_devices)
@skip_if_single_gpu
def test_inputs(reduce_buffer_size, backend, device):
    # Check that the ShardedDDP wrapper accepts several tensors as inputs
    world_size = 2

    if backend == "nccl" and device == "cpu":
        pytest.skip("Incompatible combination, or cuda not available")

    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_two_inputs,
            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )


def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as PyTorch's DDP:
    # - is_multi_device_module
    # - device_type
    with temp_files_ctx(num=1) as temp_files:
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        assert hasattr(ddp_model, "is_multi_device_module")
        assert hasattr(ddp_model, "device_type")
        assert hasattr(ddp_model, "module")
        dist.destroy_process_group()


def test_random_attributes():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP exposes the original module's attributes
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.banana = "sweet"

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        assert hasattr(ddp_model, "banana")
        assert not hasattr(ddp_model, "orange")

        dist.destroy_process_group()


def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP catches a gradient which itself requires a gradient
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()


def test_mixed_types():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP handles a model mixing trainable and non-trainable parameters of different types
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = _get_mlp(tripwire=True)

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(model, optimizer)
        input_tensor = torch.rand((2, 2))
        _ = model(input_tensor)

        dist.destroy_process_group()


def run_test_train_eval_change(rank, world_size, file):
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)
    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()


def test_train_eval_change():
    world_size = 4
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_train_eval_change,
            args=(world_size, temp_files[0]),
            nprocs=world_size,
            join=True,
        )


def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that moving the wrapped module to another device is caught (not supported)
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # deliberately not on the target device; the device is changed after the fact
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
    )
    try:
        ddp_model.to(device)
        assert False, "Changing devices should be caught and not supported"
    except AssertionError:
        pass
    # Check that we can change the data type
    ddp_model.to(device=torch.device("cpu"), dtype=torch.float16)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2**20])
def test_device_change(reduce_buffer_size):
    # Check that ShardedDDP rejects a post-wrap device change but accepts a dtype change
    world_size = 2
    backend = "nccl"

    with temp_files_ctx(num=1) as temp_files:
        device = "cuda"
        mp.spawn(
            run_test_device_change,
            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )


def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
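    """Per-rank worker: check that switching the wrapped module from train() to eval() after wrapping is handled properly."""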
    group = dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2**20])
def test_training_change(reduce_buffer_size):
    world_size = 2
    backend = "nccl"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_training_change,
            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )


def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
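    """Per-rank worker: check that ShardedDDP works with a model converted to torch.nn.SyncBatchNorm."""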
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)  # in PyTorch 1.5, SyncBN switches to the default device/CPU

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_ddp_sync_batch_norm,
            args=(world_size, backend, device, temp_files[0]),
            nprocs=world_size,
            join=True,
        )


def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
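    """Per-rank worker: check that ShardedDDP can be driven by two OSS optimizers, each owning a slice of the parameters."""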
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()


def test_two_optimizers():
    # Check that ShardedDDP can be used with two optimizers, each owning part of the parameters
    world_size = 2
    backend = "gloo"
    device = "cpu"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_two_optimizers, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True
        )


def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
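    """Per-rank worker: train a small GPT2 model (which has trainable but unused parameters) under ShardedDDP,
    and stress the .to() method in between the optimizer steps."""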
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    )
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    set_to_none = True

    def closure():
        nonlocal set_to_none
        ddp_model.zero_grad(set_to_none=set_to_none)
        set_to_none = not set_to_none

        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

        # Stress test the .to() method
        ddp_model.to(device=device, dtype=torch.float16)
        ddp_model.to(device=device, dtype=torch.float32)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("reduce_buffer", [2**23, 2**40])
def test_gpt2(world_size, reduce_buffer):
    # Check that having trainable unused params is fine
    backend = "gloo"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_gpt2,
            args=(world_size, backend, device, temp_files[0], reduce_buffer),
            nprocs=world_size,
            join=True,
        )


def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
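    """Per-rank worker: check that ShardedDDP works on a process sub-group which only includes some of the ranks."""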
    # Only the even ranks take part, to check that the global_rank indexing is used properly
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model not fitting in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, momentum is required so that the optimizer has some state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)


@skip_if_less_than_four_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2**20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_multiple_groups(reduce_buffer_size, backend):
    world_size = 4
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_multiple_groups,
            args=(world_size, temp_files[0], backend, reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )