# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing ShardedDDP
"""

from contextlib import suppress
import tempfile

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils.testing import (
    GPT2,
    SGDWithPausingCompute,
    available_devices,
    check_same_models_across_ranks,
    skip_if_less_than_four_gpu,
    skip_if_no_cuda,
    skip_if_single_gpu,
)


def _get_mlp(tripwire: bool = False):
    if not tripwire:
        return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))

    class Tripwire(torch.nn.Module):
        """A model made to expose possible corner cases
        """

        def __init__(self) -> None:
            super().__init__()
            self.model = Linear(2, 3, bias=False)

            # A type mismatch between trainable and non-trainable parameters can trip the bucketing logic, for instance
            self.register_parameter("tripwire", torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False))

        def forward(self, x):
            return self.model(x)

    return Tripwire()


class _DoubleInput(torch.nn.Module):
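    """Module which runs the same MLP over two separate inputs, used to exercise multi-input forward signatures."""
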
    def __init__(self):
        super().__init__()
        self.mlp = _get_mlp()

    def forward(self, x, y):
        x1 = self.mlp(x)
        x2 = self.mlp(y)
        return torch.cat((x1, x2), dim=1)


def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
    optimizer_type,
    reduce_fp16=False,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    next(model.parameters()).requires_grad = False  # Test non-trainable parameters

    optimizer_settings = {"lr": 1e-3, "momentum": 0.99}
    if optimizer_type == SGDWithPausingCompute:
        optimizer_settings["rank"] = rank

    optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
    ddp_model = ShardedDataParallel(
        model,
        optimizer,
        broadcast_buffers=broadcast_buffers,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=reduce_fp16,
    )

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_models_across_ranks(
        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
    )

    # Optim loop
    def closure():
        optimizer.zero_grad()

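        # With gradient accumulation, no_sync() skips the gradient reduction for this backward pass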
        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)

        # For a sync of all the streams
        if device.type == torch.device("cuda").type:
            torch.cuda.synchronize(device=device)

        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or not grad_accumulation
        check_same_models_across_ranks(
            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
        )

    dist.destroy_process_group()


def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_one_step,
        args=(
            world_size,
            backend,
            device,
            temp_file_name,
            broadcast_buffers,
            grad_accumulation,
            reduce_buffer_size,
            optimizer_type,
        ),
        nprocs=world_size,
        join=True,
    )


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
@pytest.mark.parametrize("reduce_fp16", [False, True])
@pytest.mark.parametrize(
    "setup",
    [
        [dist.Backend.NCCL, torch.device("cuda")],
        [dist.Backend.GLOO, torch.device("cpu")],
        [dist.Backend.GLOO, torch.device("cuda")],
    ],
)
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(
        run_one_step,
        args=(
            world_size,
            setup[0],
            setup[1],
            temp_file_name,
            broadcast_buffers,
            grad_accumulation,
            reduce_buffer_size,
            optimizer_type,
            reduce_fp16,
        ),
        nprocs=world_size,
        join=True,
    )


def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
@pytest.mark.parametrize("device", available_devices)
def test_inputs(reduce_buffer_size, backend, device):
    # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
    world_size = 2
    if backend == "nccl" and device == "cpu":
        pytest.skip("Incompatible combination, or cuda not available")
        return

    mp.spawn(
        run_test_two_inputs,
        args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is_multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()


def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()


def test_catch_grad_grad():
    # Check that ShardedDDP catches a .grad which itself requires a gradient (chained gradients are not supported)
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.train()
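    # Artificially give one parameter a .grad tensor which itself requires a gradient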
    chained_grad = torch.zeros_like(next(model.parameters()))
    chained_grad.requires_grad = True
    next(model.parameters()).grad = chained_grad

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    inputs = torch.rand(100, 2)
    with pytest.raises(RuntimeError):
        _ = ddp_model(inputs)

    dist.destroy_process_group()


def test_mixed_types():
    # Check that ShardedDDP handles a model mixing trainable and non-trainable parameter types
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = _get_mlp(tripwire=True)

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    _ = model(input_tensor)

    dist.destroy_process_group()


def run_test_train_eval_change(rank, world_size, file):
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()


def test_train_eval_change():
    world_size = 4
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )


def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that the wrapped module can change devices
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # not on the target device on purpose, we test moving it after the fact
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
    )
    # Changing the device post-wrap should be caught and is not supported
    with pytest.raises(AssertionError):
        ddp_model.to(device)

    # Check that we can change the data type
    ddp_model.to(device=torch.device("cpu"), dtype=torch.float16)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_device_change(reduce_buffer_size):
    # Check that ShardedDDP handles a device change properly
    world_size = 2
    backend = "nccl"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_device_change,
        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    group = dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)  # a second eval forward pass should also run cleanly

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_training_change(reduce_buffer_size):
    world_size = 2
    backend = "nccl"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_training_change,
        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)  # in pytorch 1.5 syncBN switches to the default device/cpu

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
    )


def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
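    # Dedicate a different subset of the parameters to each of the two optimizers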
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()


def test_two_optimizers():
    # Check that the ShardedDDP wrapper supports sharding the parameters over two optimizers
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    )
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Move the model to another device post-construction
    model = model.to(device)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("world_size", [1, 2])
def test_gpt2(world_size):
    # Check that having trainable unused params is fine
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only work with the even ranks, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model not fitting in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)


@skip_if_less_than_four_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_multiple_groups(reduce_buffer_size, backend):
    world_size = 4
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(
        run_test_multiple_groups,
        args=(world_size, temp_file_name, backend, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )