test_multiprocess_pipe.py 11.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing MultiProcessPipe Module
"""

import contextlib
import functools
import inspect
import os
import tempfile
12
from typing import Any, Dict, List, NamedTuple, Tuple
13
14
15
16

import pytest
import torch
import torch.distributed.autograd as dist_autograd
17
from torch.distributed.nn import RemoteModule
18
19
20
21
22
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn

23
from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
24
from fairscale.utils import torch_version
25

26
27
28
29
30
# Module-wide pytest marker: every test in this file is skipped unless CUDA is
# available and the installed torch version is at least 1.9.0.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available() or torch_version() < (1, 9, 0),
    reason="CPU tests fail right now and all tests require torch version >= 1.9.0.",
)

31
32
33
34
35
36
37
# Remote-device specifications in "<worker name>/<device>" form, one entry per worker.
CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
# Device sets the tests are parametrized over; the GPU set is only offered when CUDA is usable.
DEVICES = [CPU_DEVICES, GPU_DEVICES] if torch.cuda.is_available() else [CPU_DEVICES]

38

39
def rpc_worker(rank, world_size, init_file, func, *args):
    """Per-process entry point: join the RPC group, run ``func`` on rank 0, then shut RPC down.

    Args:
        rank: index of this process; the worker registers itself as ``"worker<rank>"``.
        world_size: total number of RPC workers taking part.
        init_file: path of the file used for the ``file://`` rendezvous.
        func: callable executed on rank 0 only.
        *args: positional arguments forwarded to ``func``.
    """
    options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
    for i in range(world_size):
        # Map this process's device index `rank` onto device index `i` of worker `i`.
        options.set_device_map("worker" + str(i), {rank: i})
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    # Only rank 0 drives the test body; every other rank goes straight to shutdown.
    if rank == 0:
        func(*args)
    rpc.shutdown()


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class RemoteModuleParams(NamedTuple):
    """Arguments for building one ``RemoteModule``, without the remote device.

    ``create_sequence_pipeline`` expands the fields as keyword arguments:
    ``RemoteModule(remote_device, **params._asdict())``.
    """
    # Module class to build remotely; callers in this file pass the class itself (e.g. ``nn.Linear``).
    module_cls: nn.Module
    # Presumably the positional constructor arguments for ``module_cls`` - confirm against RemoteModule.
    args: Tuple
    # Presumably the keyword constructor arguments for ``module_cls`` - confirm against RemoteModule.
    kwargs: Dict[str, Any]


def create_sequence_pipeline(
    layers: List[RemoteModuleParams], balance: List[int], devices: List[str], **kwargs: Any
) -> DistributedPipeline:
    """Create a pipeline from a list of pipeline-modules that run sequentially.

    Args:
        layers: list of module specifications. They should not already be assigned a remote device.
        balance: list of integers describing how the layers are partitioned. The sum of the
            numbers in 'balance' must equal the number of layers.
        devices: specification of the remote device for each partition. Must have the same
            length as 'balance'.
        **kwargs: forwarded unchanged to ``DistributedPipeline``.

    Returns:
        The ``DistributedPipeline`` built over the sequential graph of remote modules.

    Raises:
        ValueError: if 'balance' and 'devices' differ in length, or if 'balance' does not
            account for every layer exactly once.
    """
    # Fail loudly: zip() below would otherwise silently drop partitions or layers.
    if len(balance) != len(devices):
        raise ValueError(f"balance has {len(balance)} entries but devices has {len(devices)}")
    if sum(balance) != len(layers):
        raise ValueError(f"balance sums to {sum(balance)} but {len(layers)} layers were given")

    remote_modules: List[RemoteModule] = []
    index = 0
    for num_layers, remote_device in zip(balance, devices):
        next_index = index + num_layers
        # Every layer of this partition is instantiated on the partition's device.
        remote_modules.extend(RemoteModule(remote_device, **params._asdict()) for params in layers[index:next_index])
        index = next_index

    graph = PipelineModulesGraph()
    # NOTE(review): `[0]` presumably wires the first module to model input 0 - confirm
    # against PipelineModulesGraph.add_sequence.
    graph.add_sequence(remote_modules, [0])

    return DistributedPipeline(graph, **kwargs)


86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def rpc_test(world_size=1):
    """Decorator factory that registers ``func`` as a multi-process RPC test.

    The decorated function itself is returned unchanged (presumably so the spawned
    processes can still resolve it by its module-level name). A wrapper is stored in
    the module globals as ``test_<func name>``; calling it spawns ``world_size``
    processes running ``rpc_worker``, which executes ``func`` on rank 0.

    Args:
        world_size: number of RPC worker processes to spawn for the test.

    Returns:
        A decorator that registers the wrapper and returns ``func``.
    """

    def decorator(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # rpc_worker only forwards positional arguments, so bind by parameter name to
            # put keyword arguments into signature order instead of relying on dict order.
            call_args = signature.bind(*args, **kwargs).args
            # mkstemp returns an open descriptor; only the path is needed for rendezvous.
            fd, init_file = tempfile.mkstemp()
            os.close(fd)
            try:
                mp.spawn(rpc_worker, args=(world_size, init_file, func, *call_args), nprocs=world_size)
            finally:
                # Best-effort cleanup; the rendezvous file may already have been removed.
                with contextlib.suppress(FileNotFoundError):
                    os.remove(init_file)

        globals()["test_" + func.__name__] = wrapper
        return func

    return decorator


@rpc_test()
@pytest.mark.parametrize("devices", DEVICES)
def create(devices):
    """Check that a one-layer pipeline can be built on the first device of ``devices``."""
    model = [RemoteModuleParams(nn.Linear, (4, 4), {})]
    # Construction completing without an exception is the whole test.
    create_sequence_pipeline(model, balance=[1], chunks=1, devices=devices[:1])
103
104
105
106


@rpc_test()
def create_multiple_layers():
    """Check that a two-partition pipeline can be built with both partitions on ``worker0/cpu``."""
    model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
    # Construction completing without an exception is the whole test.
    create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=["worker0/cpu", "worker0/cpu"])
109
110
111
112
113


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def create_multiple_workers(devices):
    """Check that a two-partition pipeline can be built with one partition per worker."""
    model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
    # Construction completing without an exception is the whole test.
    create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=devices[:2])
116
117
118
119
120


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def parameter_rrefs(devices):
    """Check that the pipeline reports one parameter RRef per parameter of its layers."""
    model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=devices[:2])
    # nn.Linear contributes a weight and a bias; nn.ReLU has no parameters.
    rrefs = pipe.parameter_rrefs()
    assert len(rrefs) == 2


@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward(devices):
    """Run a single-ReLU pipeline with one chunk and compare the output with the expected values."""
    yh = torch.tensor([1.0, 0.0])
    x = torch.tensor([1.0, -1.0])
    model = [RemoteModuleParams(nn.ReLU, (), {})]
    pipe = create_sequence_pipeline(model, balance=[1], chunks=1, devices=devices[:1])
    # Fetch the result locally and move it to CPU so it can be compared with `yh`.
    y = pipe(x).to_here().cpu()
    assert torch.equal(y, yh), f"{y} != {yh}"


@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward_chunks(devices):
    """Run a single-ReLU pipeline with ``chunks=4`` and check the full output comes back in order."""
    yh = torch.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
    x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0])
    model = [RemoteModuleParams(nn.ReLU, (), {})]
    pipe = create_sequence_pipeline(model, balance=[1], chunks=4, devices=devices[:1])
    # Fetch the result locally and move it to CPU so it can be compared with `yh`.
    y = pipe(x).to_here().cpu()
    assert torch.equal(y, yh), f"{y} != {yh}"


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def forward_multi(devices, checkpoint):
    """Run a Linear -> ReLU pipeline across two workers for each ``checkpoint`` mode.

    Checks the output shape, that the output requires grad, and the output sum for seed 3.
    """
    # "worker0/cuda:0" -> "cuda:0": the local device on which the input is placed.
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    torch.cuda.manual_seed_all(3)
    x = torch.randn(8, 4).to(device)
    x.requires_grad = True  # TODO(msb) remove this limitation
    model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2], checkpoint=checkpoint)
    y = pipe(x).to_here()
    expected_sum = torch.tensor(5.0615)
    assert y.shape == torch.Size([8, 4])
    assert y.requires_grad is True
    assert torch.allclose(y.sum(), expected_sum), f"{y.sum()} != {expected_sum}"


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def backward(devices):
    """Run a distributed backward pass through a two-worker pipeline and count the gradients."""
    # "worker0/cuda:0" -> "cuda:0": the local device on which the input is placed.
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
    with dist_autograd.context() as context_id:
        y = pipe(x)
        # The input itself, wrapped in an RRef, serves as the regression target.
        loss = criterion(y, rpc.RRef(x))
        loss.backward(context_id)
        # Gradients must be read while the autograd context is still open.
        grads = dist_autograd.get_gradients(context_id)
    assert len(grads) == 2


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def update(devices):
    """Take two SGD steps on a two-worker pipeline and check that the loss decreases."""
    # "worker0/cuda:0" -> "cuda:0": the local device on which the input is placed.
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)
    model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
    pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=4, devices=devices[:2])
    opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05)
    losses = []
    for _ in range(2):
        # Each step gets its own distributed autograd context.
        with dist_autograd.context() as context_id:
            y = pipe(x)
            # The input itself, wrapped in an RRef, serves as the regression target.
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    # Fetch the recorded losses locally before comparing them.
    losses = [loss.to_here() for loss in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"


class ConcatenateTensors(nn.Module):
    """Module that joins all of its inputs along dimension 1."""

    def forward(self, *inputs):
        """Return the inputs concatenated along dim 1, in argument order."""
        merged = torch.cat(inputs, dim=1)
        return merged


class SplitTensors(nn.Module):
    """Module that splits its input along dimension 1 into chunks of half its width, rounded up."""

    def forward(self, input):
        """Return the tuple of dim-1 chunks of ``input``; the last chunk may be narrower."""
        width = input.shape[1]
        chunk_size = (width + 1) // 2  # ceil(width / 2)
        return torch.split(input, chunk_size, dim=1)


def extract_partitions(graph: PipelineModulesGraph, pipeline: DistributedPipeline) -> List[List[int]]:
    """Describe each pipeline partition as the list of positions its nodes hold in ``graph.nodes``."""
    partitions: List[List[int]] = []
    for partition in pipeline.partitions:
        partitions.append([graph.nodes.index(node) for node in partition.nodes])
    return partitions


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def multi_input_multi_output_layers(devices):
    """Train a hand-built graph containing a multi-output and a multi-input layer.

    Checks how the graph is partitioned, the number of parameter RRefs, and that the
    loss decreases over two SGD steps.
    """
    # "worker0/cuda:0" -> "cuda:0": the local device on which the input is placed.
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    #                                / ->linear_layer_2_1
    # input -> linear_layer1 -> split                     ->concatenate
    #                                \ ->linear_layer_2_2

    linear_layer_1 = RemoteModule(devices[0], nn.Linear, (4, 4), {})
    split = RemoteModule(devices[0], SplitTensors, (), {})
    linear_layers_2 = [
        RemoteModule(devices[0], nn.Linear, (2, 2), {}),
        RemoteModule(devices[1], nn.Linear, (2, 2), {}),
    ]
    concatenate = RemoteModule(devices[1], ConcatenateTensors, ())

    graph = PipelineModulesGraph()
    # NOTE(review): the trailing `2` is presumably the number of outputs of the last
    # module in the sequence (`split`) - confirm against PipelineModulesGraph.add_sequence.
    graph.add_sequence([linear_layer_1, split], [0], 2)
    for output_index, layer in enumerate(linear_layers_2):
        # Presumably feeds output `output_index` of `split` into this layer - confirm.
        graph.add_layer(layer, [(split, output_index)])
    graph.add_layer(concatenate, linear_layers_2)

    pipe = DistributedPipeline(graph, chunks=4)
    assert [[0, 1], [2], [3], [4]] == extract_partitions(graph, pipe)
    # Three nn.Linear layers, each with a weight and a bias.
    param_rrefs = pipe.parameter_rrefs()
    assert len(param_rrefs) == 6
    opt = DistributedOptimizer(torch.optim.SGD, param_rrefs, lr=0.05)
    losses = []
    for _ in range(2):
        # Each step gets its own distributed autograd context.
        with dist_autograd.context() as context_id:
            y = pipe(x)
            # The input itself, wrapped in an RRef, serves as the regression target.
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    # Fetch the recorded losses locally before comparing them.
    losses = [loss.to_here() for loss in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"


# A test for extracting the same graph as in test multi_input_multi_output_layers automatically
class ShardedLinearLayer(nn.Module):
    """Composite layer: split the input, apply one ``nn.Linear(2, 2)`` per shard, concatenate.

    All sub-modules are ``RemoteModule`` instances, so each stage can live on a different device.
    """
    def __init__(self, input_device, shard_devices, output_device):
        """Create the remote sub-modules.

        Args:
            input_device: remote device for the ``SplitTensors`` stage.
            shard_devices: remote devices for the two linear shards; indices 0 and 1 are used.
            output_device: remote device for the ``ConcatenateTensors`` stage.
        """
        super().__init__()
        self.split = RemoteModule(input_device, SplitTensors, (), {})
        self.linear_layers_2 = nn.ModuleList(
            [
                RemoteModule(shard_devices[0], nn.Linear, (2, 2), {}),
                RemoteModule(shard_devices[1], nn.Linear, (2, 2), {}),
            ]
        )
        self.concatenate = RemoteModule(output_device, ConcatenateTensors, ())

    def forward(self, input):
        """Split ``input``, run shard ``i`` through linear layer ``i``, and concatenate the results."""
        shards = self.split(input)
        # NOTE(review): explicit indexing is kept as-is; make_graph's tracing of this
        # method may depend on it - confirm before refactoring.
        shards = [self.linear_layers_2[i](shards[i]) for i in range(2)]
        return self.concatenate(*shards)


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def auto_graph_extract(devices):
    """Build the pipeline graph with ``make_graph`` from an ``nn.Sequential`` and train it.

    Checks how the extracted graph is partitioned, the number of parameter RRefs, and
    that the loss decreases over two SGD steps.
    """
    from fairscale.experimental.nn.distributed_pipeline.trace import make_graph

    # "worker0/cuda:0" -> "cuda:0": the local device on which the input is placed.
    device = devices[0].split("/")[1]
    torch.random.manual_seed(3)
    criterion = DistributedLoss(torch.nn.MSELoss)
    x = torch.randn(8, 4).to(device)

    # create model
    model = nn.Sequential(
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
        ShardedLinearLayer(devices[0], devices, devices[1]),
        RemoteModule(devices[0], nn.Linear, (4, 4), {}),
    )
    graph = make_graph(model)
    pipe = DistributedPipeline(graph, chunks=4)
    partitions = extract_partitions(graph, pipe)
    assert [[0, 1], [2], [3], [4], [5]] == partitions, f"partitions={partitions}"
    # Four nn.Linear layers (two outer, two shards), each with a weight and a bias.
    parameter_rrefs = pipe.parameter_rrefs()
    assert len(parameter_rrefs) == 8
    opt = DistributedOptimizer(torch.optim.SGD, parameter_rrefs, lr=0.05)
    losses = []
    for _ in range(2):
        # Each step gets its own distributed autograd context.
        with dist_autograd.context() as context_id:
            y = pipe(x)
            # The input itself, wrapped in an RRef, serves as the regression target.
            loss = criterion(y, rpc.RRef(x))
            losses.append(loss)
            loss.backward(context_id)
            opt.step(context_id)
    # Fetch the recorded losses locally before comparing them.
    losses = [loss.to_here() for loss in losses]
    assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"