# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing MultiProcessPipe Module
"""

import functools
import tempfile

import pytest
import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
from fairscale.utils.testing import torch_version

# torch <= 1.8.1 has no RPC device maps (configured below for newer versions),
# so CUDA outputs must be bounced through CPU on the remote side before being
# fetched locally.
BOUNCE_TENSORS = torch_version() <= (1, 8, 1)

CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
if torch.cuda.is_available():
    DEVICES = [CPU_DEVICES, GPU_DEVICES]
else:
    DEVICES = [CPU_DEVICES]

pytestmark = pytest.mark.skipif(torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0")


def rpc_worker(rank, world_size, init_file, func, *args):
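    # Per-process entry point: configure TensorPipe, join the RPC group, run
    # the test body on rank 0 only, and keep other ranks alive to serve RPCs.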
    if torch_version() == (1, 8, 0):
        if torch.cuda.is_available():
            # Workaround for https://github.com/pytorch/pytorch/issues/53844
            options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file, _transports=["ibv", "uv"])
        else:
            # Workaround for https://github.com/pytorch/pytorch/issues/54266
            options = rpc.TensorPipeRpcBackendOptions(
                init_method="file://" + init_file,
                _channels=["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"],
            )
    else:
        options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
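    # torch > 1.8.1 supports direct device-to-device RPC: for every peer,
    # map this rank's local device index to the peer's device index.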
    if torch_version() > (1, 8, 1):
        for i in range(world_size):
            if i != rank:
                options.set_device_map("worker" + str(i), {rank: i})
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    if rank == 0:
        func(*args)
    rpc.shutdown()


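# Decorator factory: spawns `world_size` RPC workers and registers a
# pytest-collectable "test_<name>" wrapper for the decorated function.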
def rpc_test(world_size=1):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
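            # Parametrized arguments arrive as kwargs; forward their values to rank 0.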
            mp.spawn(rpc_worker, args=(world_size, tempfile.mkstemp()[1], func, *kwargs.values()), nprocs=world_size)

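        # Expose the wrapper at module level under a "test_" name so pytest collects it.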
        globals()["test_" + func.__name__] = wrapper
        return func

    return decorator


@rpc_test()
@pytest.mark.parametrize("devices", DEVICES)
def create(devices):
    model = [("linear", nn.Linear, (4, 4), {})]
    pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])


@rpc_test()
def create_multiple_layers():
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=["worker0/cpu", "worker0/cpu"])


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def create_multiple_workers(devices):
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=devices[:2])


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def parameter_rrefs(devices):
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=devices[:2])
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 2


@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward(devices):
    yh = torch.tensor([1.0, 0.0])
    x = torch.tensor([1.0, -1.0])
    model = [("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])
    y = pipe(x).to_here().cpu()
    assert torch.equal(y, yh), f"{y} != {yh}"


@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward_chunks(devices):
    yh = torch.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0])
    model = [("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1], chunks=4, devices=devices[:1])
    y = pipe(x).to_here().cpu()
    assert torch.equal(y, yh), f"{y} != {yh}"


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def forward_multi(devices, checkpoint):
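    # "workerN/<device>" -> "<device>": build the input on the first stage's device.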
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    torch.cuda.manual_seed_all(3)
    x = torch.randn(8, 4).to(device)
    x.requires_grad = True  # TODO(msb) remove this limitation
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2], checkpoint=checkpoint)
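    # On torch <= 1.8.1, move the output to CPU on the owning worker first.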
    if BOUNCE_TENSORS:
        y = pipe(x).remote().cpu().to_here()
    else:
        y = pipe(x).to_here()
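    # Reference sum computed for manual seed 3 set above.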
    expected_sum = torch.tensor(5.0615)
    assert y.shape == torch.Size([8, 4])
    assert y.requires_grad is True
    assert torch.allclose(y.sum(), expected_sum), f"{y.sum()} != {expected_sum}"


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def backward(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
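    # Distributed autograd tracks gradients per context id (retrieved below via
    # get_gradients) rather than accumulating them in parameter .grad fields.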
    with dist_autograd.context() as context_id:
        y = pipe(x)
        loss = criterion(y, rpc.RRef(x))
        loss.backward(context_id)
        grads = dist_autograd.get_gradients(context_id)
    assert len(grads) == 2


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def update(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(torch.optim.SGD, params, lr=0.05)
    losses = []
    for i in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
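    # Fetch the remote loss values; the second step should see a lower loss.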
    losses = [loss.to_here() for loss in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"