# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from torch import nn

from fairscale.nn.pipe import AsyncPipe
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def inplace_on_requires_grad(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")

    x = torch.rand(1)

    if pipe_class == AsyncPipe and model.group.rank() == 0:
        # With AsyncPipe, the model will wait forever for gradients unless it is in eval mode
        model.eval()

    y = model(x)

    message = r"a leaf Variable that requires grad .* used in an in-place operation."
    if model.group.rank() == 1:
        with pytest.raises(RuntimeError, match=message):
            y.backward()

    torch.distributed.barrier()

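# The sketch below is not part of the original test suite; it is a minimal,
# standalone illustration (helper name is ours, never invoked by the pipeline
# test) of the autograd rule the test above relies on: an in-place operation
# on a leaf tensor that requires grad raises a RuntimeError.
def _sketch_inplace_on_leaf_requiring_grad():
    x = torch.rand(1, requires_grad=True)
    with pytest.raises(RuntimeError, match="leaf Variable that requires grad"):
        x.relu_()  # in-place ReLU on a grad-requiring leaf is rejected by autograd
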

@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def inplace_on_not_requires_grad(pipe_class):
    # In-place operation on a tensor not requiring grad doesn't cause a
    # RuntimeError. Currently, we cannot detect this case.
    model = nn.Sequential(nn.ReLU(inplace=True))
    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")

    x = torch.rand(1)
    y = model(x)
    del model

    message = r"a leaf Variable that requires grad .* used in an in-place operation."
    with pytest.raises(RuntimeError, match=message):
        y.backward()

    torch.distributed.barrier()

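# Standalone sketch (ours, not invoked by the test above): autograd allows
# in-place operations on tensors that do not require grad, which is why the
# pipeline has no error to surface here and the test is a strict xfail.
def _sketch_inplace_without_grad():
    x = torch.rand(1)  # requires_grad defaults to False
    x.relu_()          # silently succeeds; nothing for checkpointing to detect
    assert not x.requires_grad
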

@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def inplace_incorrect_grad(pipe_class):
    class M(nn.Module):
        def forward(self, foo_bar):
            # 'foo' requires grad but 'bar' does not. In-place operation on
            # 'bar' won't cause a RuntimeError.
            foo, bar = foo_bar

            # add_(1) is not idempotent, in contrast to relu_(). If it is
            # executed multiple times, it accumulates each increment onto
            # 'bar'.
            bar.add_(1)

            # 'bar' is still captured by checkpointing. 'foo' will get
            # incorrect grad.
            return foo * bar

    model = nn.Sequential(M())
    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")

    foo = torch.tensor([1.0], requires_grad=True)
    bar = torch.tensor([1.0])

    output = model((foo, bar))
    del model
    output.backward()

    # The gradient of 'foo' should be 2, but it is actually 3 because
    # bar.add_(1) was executed twice due to checkpointing.
    assert foo.grad.item() == 2.0
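

# Hand-rolled sketch (ours, using plain autograd rather than the pipe API) of
# why checkpoint="always" produces the wrong gradient above: the forward pass
# runs once without autograd and is re-run at backward time, so the
# non-idempotent bar.add_(1) executes twice and 'foo' sees a gradient of 3.0
# instead of 2.0.
def _sketch_recompute_with_inplace_add():
    foo = torch.tensor([1.0], requires_grad=True)
    bar = torch.tensor([1.0])

    with torch.no_grad():  # first execution: activations are discarded
        bar.add_(1)
        _ = foo * bar

    bar.add_(1)            # recomputation at "backward" time
    output = foo * bar
    output.backward()

    assert foo.grad.item() == 3.0  # a single execution would have given 2.0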