# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing OssDdp class.
"""

import os
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel

# Shared skip conditions: the GPU tests require CUDA, and the NCCL path needs
# one device per rank, so at least two GPUs for the default world_size of 2.
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")


def test_on_cpu():
    """Run the one-step ShardedDataParallel check on CPU with the Gloo backend."""
    cpu_device = torch.device("cpu")
    run_test(backend=dist.Backend.GLOO, device=cpu_device)


@skip_if_no_cuda
@skip_if_single_gpu
def test_on_gpu():
    """Run the one-step ShardedDataParallel check on GPUs with the NCCL backend."""
    cuda_device = torch.device("cuda")
    run_test(backend=dist.Backend.NCCL, device=cuda_device)


def run_one_step(rank, world_size, backend, device, temp_file_name):
    """Per-rank worker (launched by ``mp.spawn``): run one forward/backward
    step through ShardedDataParallel, then verify gradient reduction and
    buffer broadcast.

    Args:
        rank (int): this process's rank within the group.
        world_size (int): total number of participating processes.
        backend: distributed backend to initialize (GLOO on CPU, NCCL on GPU).
        device (torch.device): device for the model and inputs.
        temp_file_name (str): shared file used for file:// rendezvous.
    """
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        # Bind one GPU per rank before any collective runs.
        torch.cuda.set_device(rank)

    # Any model works. Add one different buffer per rank
    # (rank-valued, so the broadcast check at the bottom is meaningful).
    model = Sequential(Linear(2, 3)).to(device)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)

    def weights_init(m):
        # Deterministic constant init so the expected gradient values below are exact.
        if isinstance(m, Linear):
            torch.nn.init.constant_(m.weight.data, 1.0)
            torch.nn.init.constant_(m.bias.data, 1.0)

    model.apply(weights_init)
    # NOTE: this second .to(device) is NOT redundant — register_buffer above
    # created the buffer on CPU, and this call moves it to the target device.
    model.to(device)

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.01, "momentum": 0.99},
        world_size=world_size,
        broadcast_buffers=True,
    )
    # The wrapper owns the sharded optimizer and the (possibly re-wrapped) module.
    optimizer = ddp.optimizer
    model = ddp.module

    # Different input per rank, allows for checking that the gradients have been properly reduced
    input_tensor = (torch.ones((64, 2)) * rank).to(device)
    output = ddp(input_tensor).abs().sum()
    output.backward()
    ddp.reduce()

    # Check that all the grads have been populated, for the shard.
    # Expected values assume world_size=2 with the rank-scaled all-ones inputs above.
    for pg in optimizer.optim.param_groups:
        for param in pg["params"]:
            if param.shape == torch.Size([3, 2]):
                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
            if param.shape == torch.Size([3]):
                assert param.grad[0].cpu() == torch.tensor([64.0])

    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
    for b in model.buffers():
        assert b.cpu().item() == 0.0

    dist.destroy_process_group()

82
83
84
85

def run_test(backend, device, world_size=2):
    """Spawn ``world_size`` worker processes that each run ``run_one_step``.

    Args:
        backend: distributed backend handed to the workers.
        device (torch.device): device the workers should run on.
        world_size (int): number of processes to spawn (default 2 — the
            gradient checks in ``run_one_step`` assume this value).
    """
    # mkstemp() returns (fd, path); the original code dropped the fd on the
    # floor, leaking an open file descriptor per test. Only the path is needed
    # for the file:// rendezvous, so close the descriptor right away.
    fd, temp_file_name = tempfile.mkstemp()
    os.close(fd)
    mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


def run_eval_mode(_unused):
    """Testing eval mode, make sure this is no asserts.

    Verifies that eval-mode forward passes can be repeated freely, while a
    repeated forward in train mode (without a grad reduce in between) raises
    a RuntimeError.

    Args:
        _unused: rank argument injected by ``mp.spawn``; ignored.
    """
    # mkstemp() returns (fd, path); close the descriptor so it doesn't leak —
    # only the path is needed for the file:// rendezvous.
    fd, sync_file = tempfile.mkstemp()
    os.close(fd)
    dist.init_process_group(
        init_method=f"file://{sync_file}", backend=dist.Backend.GLOO, rank=0, world_size=1
    )
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer_params = {"lr": 0.1, "momentum": 0.99}
    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)

    # Eval mode: repeated forward passes must be allowed.
    ddp.eval()
    for _ in range(5):
        input_tensor = torch.rand((64, 2))
        _ = ddp(input_tensor)

    # Train mode: a second forward without reducing grads should raise.
    ddp.train()
    try:
        for _ in range(5):
            input_tensor = torch.rand((64, 2))
            _ = ddp(input_tensor)
    except RuntimeError:
        pass
    else:
        assert False, "Multiple forward passes on training mode should not pass"

    dist.destroy_process_group()

Min Xu's avatar
Min Xu committed
115
116
117

def test_eval_mode():
    """Exercise eval/train-mode forward behavior in a single spawned process."""
    mp.spawn(run_eval_mode, args=(), nprocs=1, join=True)