test_phony.py 1.91 KB
Newer Older
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from fairscale.nn.pipe.phony import get_phony


def test_phony_size():
    """get_phony returns a 1-D tensor holding exactly one element."""
    p = get_phony(torch.device("cpu"), requires_grad=False)
    assert p.size() == (1,)
Mandeep Singh Baines's avatar
Mandeep Singh Baines committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def test_phony_requires_grad():
    """get_phony honours the ``requires_grad`` flag it is asked for."""
    cpu = torch.device("cpu")
    with_grad = get_phony(cpu, requires_grad=True)
    without_grad = get_phony(cpu, requires_grad=False)
    assert with_grad.requires_grad
    assert not without_grad.requires_grad


def test_cached_phony():
    """Identical get_phony arguments yield the very same tensor object."""
    cpu = torch.device("cpu")

    grad_first = get_phony(cpu, requires_grad=True)
    grad_second = get_phony(cpu, requires_grad=True)
    assert grad_first is grad_second

    nograd_first = get_phony(cpu, requires_grad=False)
    nograd_second = get_phony(cpu, requires_grad=False)
    assert nograd_first is nograd_second

    # Different requires_grad settings must not share one tensor object.
    assert grad_first is not nograd_first


def test_phony_in_autograd_function():
    """A phony emitted from a custom autograd Function is a distinct, tracked tensor.

    The Function's output must not be the cached phony returned by a plain
    get_phony call. It must carry a grad_fn, while the plain phony has none.
    """

    class Phonify(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            # detach() produces a new tensor object, so the output is not the
            # cached phony itself.
            return get_phony(tensor.device, requires_grad=False).detach()

    source = torch.rand(1, requires_grad=True)

    from_function = Phonify.apply(source)
    cached = get_phony(torch.device("cpu"), requires_grad=True)

    assert from_function is not cached
    # `source` requires grad, so autograd records the Function on its output.
    assert from_function.grad_fn is not None
    assert cached.grad_fn is None