# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from fairscale.nn.misc import GradBucket
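
# GradBucket stores parameter gradients as slices of a single flat buffer;
# the tests below exercise its add_grad / shrink / collapse / rebuild lifecycle.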


def test_grad_values_conserved():
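    """Check that a gradient moved into a GradBucket keeps its original values."""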
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(10, param.dtype, param.device, -1)
        param_ = param.clone()
        param_.grad = param.grad.clone()  # clone() does not carry over .grad

        bucket.add_grad(param_)
        assert torch.allclose(param.grad, param_.grad)


def test_memory_leak():
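    """Check that shrink() trims the buffer storage down to the gradients it actually holds."""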
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()

        storage = bucket.buffer.storage()
        # See https://github.com/pytorch/pytorch/pull/59671/
        if hasattr(storage, "nbytes"):
            assert storage.nbytes() == 6 * bucket.buffer.element_size()
        else:
            assert len(storage) == 6


def test_max_size():
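    """Check that adding a gradient larger than the bucket capacity raises."""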
    with torch.no_grad():  # remove a warning
        param = torch.rand((20, 30), requires_grad=True)
        param.grad = torch.rand(20, 30)

        bucket = GradBucket(5, param.dtype, param.device, -1)
        with pytest.raises(AssertionError):
            bucket.add_grad(param)


def test_collapse():
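    """Check that collapse() frees the buffer and drops the gradients,
    and that rebuild() restores zero-valued gradients of the original shape."""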
    with torch.no_grad():  # remove a warning
        size = (5, 6)
        param = torch.rand(size, requires_grad=True)
        param.grad = torch.rand(size)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()
        bucket.collapse()

        assert bucket.buffer.numel() == 0
        assert param.grad is None
        bucket.rebuild()

        assert param.grad is not None
        assert torch.allclose(param.grad, torch.zeros(size))