# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing MultiProcessPipe Module
"""

import functools
import tempfile

import pytest
import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
from fairscale.utils.testing import torch_version

# Device placement strings take the form "worker{rank}/{device}".
CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
if torch.cuda.is_available():
    DEVICES = [CPU_DEVICES, GPU_DEVICES]
else:
    DEVICES = [CPU_DEVICES]

pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0")


def rpc_worker(rank, world_size, init_file, func, *args):
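    """Per-process entry point: initialize RPC, run ``func`` on rank 0, then shut down."""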
    options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
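    # For calls to each peer, map this rank's device index to the peer's device `i`
    # so tensors can be transferred directly over RPC (relevant for the CUDA tests).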
    for i in range(world_size):
        options.set_device_map("worker" + str(i), {rank: i})
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
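
    # Only rank 0 executes the test body; the other ranks serve RPC requests until shutdown.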
    if rank == 0:
        func(*args)
    rpc.shutdown()


def rpc_test(world_size=1):
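    """Decorator: run the wrapped function on ``world_size`` spawned RPC workers and
    register it with pytest under a ``test_`` name."""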
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
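            # pytest passes parametrized arguments as keyword arguments, so forward
            # kwargs.values() positionally to each spawned worker.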
            mp.spawn(rpc_worker, args=(world_size, tempfile.mkstemp()[1], func, *kwargs.values()), nprocs=world_size)

        # Register the wrapper under a "test_" name so pytest collects it.
        globals()["test_" + func.__name__] = wrapper
        return func

    return decorator


@rpc_test()
@pytest.mark.parametrize("devices", DEVICES)
def create(devices):
    model = [("linear", nn.Linear, (4, 4), {})]
    pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])


@rpc_test()
def create_multiple_layers():
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=["worker0/cpu", "worker0/cpu"])


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def create_multiple_workers(devices):
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=devices[:2])


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def parameter_rrefs(devices):
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=devices[:2])
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 2  # linear1's weight and bias; nn.ReLU has no parameters


@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward(devices):
    device = devices[0].split("/")[1]
    yh = torch.tensor([1.0, 0.0])
    x = torch.tensor([1.0, -1.0]).to(device)
    model = [("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])
    y = pipe(x).to_here().cpu()
    assert torch.equal(y, yh), f"{y} != {yh}"


@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward_chunks(devices):
    device = devices[0].split("/")[1]
    yh = torch.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0]).to(device)
    model = [("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1], chunks=4, devices=devices[:1])
    y = pipe(x).to_here().cpu()
    assert torch.equal(y, yh), f"{y} != {yh}"


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def forward_multi(devices, checkpoint):
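    """Chunked forward across two pipeline stages, exercising each checkpoint mode."""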
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    torch.cuda.manual_seed_all(3)
    x = torch.randn(8, 4).to(device)
    x.requires_grad = True  # TODO(msb) remove this limitation
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2], checkpoint=checkpoint)
    y = pipe(x).to_here()
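    # Expected sum is fixed by the manual seed above; it guards against regressions
    # in how chunks are split and recombined.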
    expected_sum = torch.tensor(5.0615)
    assert y.shape == torch.Size([8, 4])
    assert y.requires_grad is True
    assert torch.allclose(y.sum(), expected_sum), f"{y.sum()} != {expected_sum}"


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def backward(devices):
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
    with dist_autograd.context() as context_id:
        y = pipe(x)
        loss = criterion(y, rpc.RRef(x))
        loss.backward(context_id)
        grads = dist_autograd.get_gradients(context_id)
    assert len(grads) == 2  # gradients for linear1's weight and bias on this worker


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def update(devices):
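    """Two DistributedOptimizer steps over the pipe's parameter RRefs should reduce the loss."""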
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(torch.optim.SGD, params, lr=0.05)
    losses = []
    for _ in range(2):
        with dist_autograd.context() as context_id:
            y = pipe(x)
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    losses = [loss.to_here() for loss in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"