# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from copy import deepcopy
import os
import time

import pytest
import torch
from torch import nn

from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import AsyncPipe
from fairscale.nn.pipe.types import LazyModule
from fairscale.utils import torch_version
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def parameters(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    pipe = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1)
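    # With balance=[1] the whole model lives on rank 0, so only rank 0
    # should report any parameters; the other rank holds no partition.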
    if torch.distributed.get_rank() == 0:
        assert list(pipe.parameters()) != []
    else:
        assert list(pipe.parameters()) == []


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def infiniband():
    if torch.distributed.get_rank() == 0:
        t = torch.Tensor(range(100)).cuda()
        torch.distributed.broadcast(t, 0)
    else:
        t = torch.empty(100).cuda()
        torch.distributed.broadcast(t, 0)

    assert torch.equal(t, torch.Tensor(range(100)).cuda())
    print(f"t on {torch.distributed.get_rank()} is {t}")


@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def infiniband2():
    if torch.distributed.get_rank() == 0:
        t = torch.Tensor(range(100)).cuda()
        torch.distributed.send(t, 1, group=get_pipeline_parallel_group())
    else:
        t = torch.empty(100).cuda()
        torch.distributed.recv(t, 0, group=get_pipeline_parallel_group())

    assert torch.equal(t, torch.Tensor(range(100)).cuda())
    print(f"t on {torch.distributed.get_rank()} is {t}")


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def infiniband3():
    t = torch.Tensor(range(100)).cuda()
    torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM)
    assert torch.equal(t, torch.Tensor(range(0, 200, 2)).cuda())


@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def mpi():
    seed = 1234
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.distributed.barrier()
    tensor_size = (1024, 1024, 10)
    torch.cuda.set_device(torch.distributed.get_rank())  # need to pin device or ucx gets unhappy

    if torch.distributed.get_rank() == 0:
        # t = torch.Tensor(range(10)).cuda(0)
        t = torch.rand(*tensor_size).cuda(0)
        torch.distributed.send(t, 1, tag=1234)
    else:
        t = torch.empty(*tensor_size).cuda(1)
        torch.distributed.recv(t, 0, tag=1234)
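        # Both ranks were seeded identically and have each drawn from the RNG
        # once, so the tensor generated locally below should match the one
        # received from rank 0.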
        t2 = torch.rand(*tensor_size).cuda(1)

        assert torch.equal(t, t2)


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def public_attrs(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))

    pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42, checkpoint="always",)

    assert pipe.balance == [1]
    assert pipe.chunks == 42
    assert isinstance(pipe.chunks, int)
    assert pipe.checkpoint == "always"
    assert isinstance(pipe.checkpoint, str)


@torch_spawn([2])
@pytest.mark.parametrize("balance", [[2], [1, 1]])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def sequential_like(balance, pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = pipe_class(model, balance, worker_map=get_worker_map())
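    # The pipe should keep nn.Sequential semantics (len, indexing, iteration),
    # but only over the partition owned by the local rank.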

    if balance == [2]:
        if torch.distributed.get_rank() == 0:
            assert len(model) == 2
            assert list(model) == [a, b]

            assert model[0] is a
            assert model[1] is b
            with pytest.raises(IndexError):
                _ = model[2]

            assert model[-1] is b
            assert model[-2] is a
        else:
            assert len(model) == 0
            assert list(model) == []
    else:
        assert len(model) == 1
        if torch.distributed.get_rank() == 0:
            assert list(model) == [a]
            assert model[0] is a
            assert model[-1] is a
        else:
            assert list(model) == [b]
            assert model[0] is b
            assert model[-1] is b

        with pytest.raises(IndexError):
            _ = model[1]


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def balance_wrong_length(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)

    with pytest.raises(ValueError):
        pipe_class(model, balance=[1], worker_map=get_worker_map())

    with pytest.raises(ValueError):
        pipe_class(model, balance=[3], worker_map=get_worker_map())


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def balance_less_than_1(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)

    with pytest.raises(ValueError):
        pipe_class(model, balance=[0, 2], worker_map=get_worker_map())

    with pytest.raises(ValueError):
        pipe_class(model, balance=[-1, 3], worker_map=get_worker_map())


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def chunks_less_than_1(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))

    with pytest.raises(ValueError):
        pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=0)

    with pytest.raises(ValueError):
        pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=-1)


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def too_few_devices(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))

    with pytest.raises(IndexError):
        # len(balance) > group.size()
        model = pipe_class(model, balance=[1, 1, 1, 1], worker_map=get_worker_map())


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def batch_size_indivisible(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)

    with pytest.warns(None) as record:
        model(torch.rand(7, 1))

    # Indivisible batch size is legal.
    assert not record


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def batch_size_small(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)

    with pytest.warns(None) as record:
        model(torch.rand(2, 1))

    # Batch size smaller than chunks is legal.
    assert not record


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_mode(pipe_class):
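    # Count occurrences of a named grad_fn node by walking the autograd graph.
    # Note: `visited=set()` is a shared mutable default, so visited nodes
    # persist across the three calls below; harmless here because each forward
    # pass builds fresh CheckpointBackward nodes.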
    def count_grad_fn(grad_fn, name, visited=set()):
        if grad_fn in visited:
            return 0
        visited.add(grad_fn)

        if grad_fn is None:
            return 0
        if grad_fn.__class__.__name__ == name:
            return 1

        counter = 0
        for next_grad_fn, _ in grad_fn.next_functions:
            counter += count_grad_fn(next_grad_fn, name, visited=visited)
        return counter

    model = nn.Sequential(nn.Linear(1, 1))
    input = torch.rand(2, 1)

    always = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always",)
    except_last = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last",)
    never = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never",)

    always_output = always(input)
    except_last_output = except_last(input)
    never_output = never(input)

    assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2
    assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1
    assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_mode_invalid(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))

    with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
        pipe_class(
            model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="INVALID_CHECKPOINT",
        )


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_mode_when_chunks_1(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))

    # All checkpoint modes are fine.
    pipe_class(
        model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
    )
    pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint="always")
    pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint="never")


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_eval(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
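    # Checkpointing should only be active in training mode: the autograd graph
    # contains CheckpointBackward/RecomputeBackward nodes after model.train(),
    # and neither after model.eval().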
    input = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        for next_grad_fn, _ in grad_fn.next_functions:
            if find_grad_fn(next_grad_fn, name):
                return True
        return False

    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
    assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")

    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
    assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")


@torch_spawn([2])
@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_non_float_input(pipe_class):
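    # A non-float (bool) tensor forked alongside the float activation must
    # survive checkpointing; backward on the final stage should still succeed.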
    class ForkNonFloat(nn.Module):
        def forward(self, input):
            return (input * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input):
            return input[0] * 2

    model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    model = pipe_class(model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always",)

    input = torch.rand(1, requires_grad=True)
    output = model(input)
    if model.group.rank() == 1:
        # with torch.autograd.detect_anomaly():
        output.backward()

    torch.distributed.barrier()


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def no_grad(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2)
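    # Capture the partition output with a forward hook; under torch.no_grad()
    # it should not be attached to an autograd graph.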
    input = torch.rand(2, 1)

    latent = None

    def hook(module, input, output):
        _ = module
        _ = input

        nonlocal latent
        latent = output

    partition = model.partition
    partition.register_forward_hook(hook)

    with torch.no_grad():
        model(input)

    assert latent.grad_fn is None


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def exception(pipe_class):
    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        def forward(self, *_):
            raise ExpectedException()

    model = nn.Sequential(Raise())
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1)

    with pytest.raises(ExpectedException):
        model(torch.rand(1))


# FIXME(tom) should probably signal to all hosts in group to stop
@torch_spawn([4])
@pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def exception_early_stop_asap(pipe_class):
    """Even if the first partitions have finished processing, the partition
    before the failed partition should be killed as soon as possible.
    """

    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    counter = 0

    class Counter(nn.Module):
        def forward(self, x):
            time.sleep(0.1)

            nonlocal counter
            counter += 1

            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
    model = pipe_class(model, [1, 1, 1, 1], worker_map=get_worker_map(), chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))

    # If the early stop doesn't work, it would be 3 instead.
    assert counter == 2


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def input_pair(pipe_class):
    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a_and_b):
            a, b = a_and_b
            return (self.fc_a(a), self.fc_b(b))

    model = nn.Sequential(Two())
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = model((a, b))
    loss = (a_out + b_out).mean()
    loss.backward()

    assert a.grad is not None
    assert b.grad is not None


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def input_singleton(pipe_class):
    class One(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, only_a):
            (a,) = only_a
            return (self.fc(a),)

    model = nn.Sequential(One())
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)

    a = torch.rand(10, 1, requires_grad=True)

    (a_out,) = model((a,))
    loss = a_out.mean()
    loss.backward()

    assert all(p.grad is not None for p in model.parameters())
    assert a.grad is not None


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def input_varargs(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map())

    a = torch.rand(1)
    b = torch.rand(1)

    # TypeError: forward() takes 2 positional arguments but 3 were given
    with pytest.raises(TypeError):
        model(a, b)


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def non_tensor(pipe_class):
    class NonTensor(nn.Module):
        def forward(self, _):
            return "hello"

    model = nn.Sequential(NonTensor())
    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
    x = torch.rand(1)

    # TypeError: expected Tensor as element 0 in argument 0, but got str
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model("hello")


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def non_tensor_tuple(pipe_class):
    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")

    model = nn.Sequential(NonTensorTuple())
    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
    x = torch.rand(1)

    # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
    with pytest.raises(TypeError):
        model(x)

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model((x, "hello"))


@torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def deferred_batch_norm(checkpoint, lazy, pipe_class):
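    # With deferred_batch_norm=True, running statistics accumulated across
    # chunks should match a plain BatchNorm2d that saw the full batch.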
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
    pipe_fn = lambda: pipe_bn  # noqa: E731
    if lazy:
        model = [LazyModule(pipe_fn)]
    else:
        model = nn.Sequential(pipe_bn)
    pipe = pipe_class(
        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    bn(x).mean().backward()

    assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
    assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)


@torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always"])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
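    # Deferred batch norm must not detach the affine parameters: their
    # gradients should match a plain BatchNorm2d run on the same batch.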
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
    pipe_fn = lambda: pipe_bn  # noqa: E731
    if lazy:
        model = [LazyModule(pipe_fn)]
    else:
        model = nn.Sequential(pipe_bn)
    pipe = pipe_class(
        model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    bn(x).mean().backward()

    assert pipe[0].weight.grad is not None
    assert pipe[0].bias.grad is not None

    assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4)
    assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4)


@torch_spawn([4])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def devices(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
    c = nn.Linear(1, 1)

    # There are two extra ranks.
    model = nn.Sequential(a, b, c)
    model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map())

    # Extra devices must be discarded.
    if model.group.rank() == 3:
        assert model.pipeline is None


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def partitions(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = pipe_class(model, [1, 1], worker_map=get_worker_map())

    assert isinstance(model.partition, nn.Sequential)

    if model.group.rank() == 0:
        assert model[0].weight == a.weight
    else:
        assert model[0].weight == b.weight


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def deny_moving(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(a, b)
    model = pipe_class(model, [1, 1], worker_map=get_worker_map())

    model.cuda()
    model.cpu()
    model.to(torch.device("cuda"))
    model.to(0)
    model.to("cuda")
    model.to(device=0)
    model.to(torch.rand(1))
    model.to(tensor=torch.rand(1))

    # Casting is allowed.
    model.half()
    model.to(torch.double)
    model.to(dtype=torch.float)


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def empty_module(pipe_class):
    # Empty sequential module is not illegal.
    model = nn.Sequential()
    model = pipe_class(model, [], worker_map=get_worker_map())

    assert model(torch.tensor([42])) == torch.tensor([42])
    assert model((torch.tensor([42]),)) == (torch.tensor([42]),)

    # But only a tensor or a tuple of tensors is legal in MultiProcessPipe.

    with pytest.raises(TypeError):
        model(42)


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
@pytest.mark.skip(reason="TODO(msb) handle named_children")
def named_children(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)

    model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
    model = pipe_class(model, [1, 1], worker_map=get_worker_map())

    names = set(n for n, _ in model.named_modules())
    if model.group.rank() == 0:
        assert "0.a" in names
    else:
        assert "0.b" in names

    # MultiProcessPipe doesn't support __getattr__. Unlike nn.Sequential, MultiProcessPipe requires
    # several methods in its namespace.
    with pytest.raises(AttributeError):
        model.a


@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def recommend_auto_balance(pipe_class):
    with pytest.raises(ValueError):
        # module and sum of balance have different length (module: 0, sum of balance: 1)
        pipe_class(nn.Sequential(), [1])

    with pytest.raises(ValueError):
        # module and sum of balance have different length (module: 2, sum of balance: 1)
        pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def lazy_construction(pipe_class):
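    # LazyModule defers construction until partitioning, so each rank should
    # instantiate only the two modules of its own partition (init_count == 2).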
    init_count = 0

    class Custom(nn.Module):
        def __init__(self):
            super(Custom, self).__init__()
            nonlocal init_count
            init_count += 1

        def forward(self, x):
            return x

    model = [
        LazyModule(lambda: Custom()),
        LazyModule(lambda: Custom()),
        LazyModule(lambda: Custom()),
        LazyModule(lambda: Custom()),
    ]

    pipe = pipe_class(model, balance=[2, 2], worker_map=get_worker_map())

    assert isinstance(pipe[0], Custom)
    assert isinstance(pipe[1], Custom)
    assert len(pipe) == 2
    assert init_count == 2


@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def missing_worker_map(pipe_class):
    model = nn.Sequential(nn.ReLU(), nn.ReLU())

    with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
        pipe_class(model, [1, 1])


@torch_spawn([2])
@pytest.mark.skip(reason="currently broken")
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
    class Surrogate(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

    conv = nn.Conv2d(3, 3, 1)
    model = nn.Sequential(Surrogate(conv), Surrogate(conv))

    # FIXME(tom) can't have duplicate params with separate processes
    with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
        pipe_class(model, [1, 1], worker_map=get_worker_map())


@torch_spawn([4])
def async_event_loop():
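    # Smoke test for the async event loop: run a 4-stage pipeline with 10
    # chunks; only the final stage holds the output and runs backward.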

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
    pipe = AsyncPipe(model, [1, 1, 1, 1], worker_map=get_worker_map(), chunks=10)

    inputs = torch.rand(100, 10)

    output = pipe(inputs)
    if pipe.final_stage:
        loss = output.mean()
        loss.backward()