# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing ShardedDDP
"""

from contextlib import suppress
import tempfile

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils.testing import (
    GPT2,
    SGDWithPausingCompute,
    available_devices,
    check_same_models_across_ranks,
    skip_if_less_than_four_gpu,
    skip_if_no_cuda,
    skip_if_single_gpu,
)


def _get_mlp():
    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))


class _DoubleInput(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = _get_mlp()

    def forward(self, x, y):
        x1 = self.mlp(x)
        x2 = self.mlp(y)
        return torch.cat((x1, x2), dim=1)


def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
    optimizer_type,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    next(model.parameters()).requires_grad = False  # Test non-trainable parameters
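    # (a frozen parameter never receives a gradient, which the reduction logic has to tolerate)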

    optimizer_settings = {"lr": 1e-3, "momentum": 0.99}
    if optimizer_type == SGDWithPausingCompute:
        optimizer_settings["rank"] = rank

    optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
    ddp_model = ShardedDataParallel(
        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
    )

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    check_same_models_across_ranks(
        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
    )

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
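            # When grad_accumulation is set, no_sync() skips the gradient reduction for
            # this backward pass, so the model replicas are allowed to diverge across
            # ranks; the check at the end of each step only expects identical params
            # when the reduction actually ran (or on cpu/gloo).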
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay the same in between the ranks
    for i in range(5):
        _ = optimizer.step(closure=closure)

        # For a sync of all the streams
        if device.type == torch.device("cuda").type:
            torch.cuda.synchronize(device=device)

        # when running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or not grad_accumulation
        check_same_models_across_ranks(
            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
        )

    dist.destroy_process_group()


def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_one_step,
        args=(
            world_size,
            backend,
            device,
            temp_file_name,
            broadcast_buffers,
            grad_accumulation,
            reduce_buffer_size,
            optimizer_type,
        ),
        nprocs=world_size,
        join=True,
    )


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
@pytest.mark.parametrize(
    "setup",
    [
        [dist.Backend.NCCL, torch.device("cuda")],
        [dist.Backend.GLOO, torch.device("cpu")],
        [dist.Backend.GLOO, torch.device("cuda")],
    ],
)
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup):
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(
        run_one_step,
        args=(
            world_size,
            setup[0],
            setup[1],
            temp_file_name,
            broadcast_buffers,
            grad_accumulation,
            reduce_buffer_size,
            optimizer_type,
        ),
        nprocs=world_size,
        join=True,
    )


def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
@pytest.mark.parametrize("device", available_devices)
def test_inputs(reduce_buffer_size, backend, device):
    # Check that the ShardedDDP wrapper accepts multiple tensor inputs
    world_size = 2
    if backend == "nccl" and device == "cpu":
        pytest.skip("Incompatible combination, or cuda not available")
        return

    mp.spawn(
        run_test_two_inputs,
        args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
    # - is multi_device_module
    # - device_type
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "is_multi_device_module")
    assert hasattr(ddp_model, "device_type")
    dist.destroy_process_group()


def test_random_attributes():
    # Check that ShardedDDP exposes the original module's attributes
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.banana = "sweet"

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert hasattr(ddp_model, "banana")
    assert not hasattr(ddp_model, "orange")

    dist.destroy_process_group()


def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that moving the wrapped module to another device is caught and rejected
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # deliberately kept on CPU, we test moving it after the fact
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
    )
    try:
        ddp_model.to(device)
        assert False, "Changing devices should be caught and not supported"
    except AssertionError:
        pass

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_device_change(reduce_buffer_size):
    # Check that ShardedDDP does not allow changing the device of the wrapped module after construction
    world_size = 2
    backend = "nccl"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_device_change,
        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    group = dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
    _ = outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
    ddp_model(inputs)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_training_change(reduce_buffer_size):
    world_size = 2
    backend = "nccl"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_training_change,
        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )


def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)  # in pytorch 1.5 syncBN switches to the default device/cpu

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    assert isinstance(model[1], torch.nn.SyncBatchNorm)
    # Ensures sync batch norm handles have been added
    ddp_model(torch.randn(2, 2).to(device))
    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(
        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
    )


def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()


def test_two_optimizers():
    # Check that the ShardedDDP wrapper supports several OSS optimizers attached to the same model
    world_size = 2
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cpu"
    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
        loss = ddp_model(input_tensor).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for i in range(STEPS):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("world_size", [1, 2])
def test_gpt2(world_size):
    # Check that having trainable unused params is fine
    backend = "gloo"
    temp_file_name = tempfile.mkstemp()[1]
    device = "cuda"
    mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
    # Only work with the even ranks, to check that the global_rank indexing is properly used
    dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)

    sub_group_ranks = [0, 2]
    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend=backend)
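    # Ranks 0 and 2 get sub-group indices 0 and 1, which differ from their global ranks
    # and thus exercise the global_rank bookkeeping mentioned above.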

    # Make sure that all the ranks get different training data
    # So that the sync check in between their models is meaningful
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Standard deep learning setup
    device = "cuda"
    torch.cuda.set_device(rank)

    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
    loss_fn = torch.nn.L1Loss().to(device)

    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )

    if rank in sub_group_ranks:
        # Model which does not fit in the broadcast bucket
        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
            device
        )

        # With SGD, Momentum is required to get a state to shard
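        # (plain SGD keeps no per-parameter state, so OSS would have nothing to shard otherwise)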
        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(
            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
        )
        check(optimizer, model)

    dist.destroy_process_group(process_group)


@skip_if_less_than_four_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_multiple_groups(reduce_buffer_size, backend):
    world_size = 4
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(
        run_test_multiple_groups,
        args=(world_size, temp_file_name, backend, reduce_buffer_size),
        nprocs=world_size,
        join=True,
    )