test_param_bucket.py 1.48 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from fairscale.nn.misc import ParamBucket


def test_param_values_conserved():
    """Check that adding a param to a ParamBucket does not alter its values.

    Bug fix: the original called ``torch.allclose`` and discarded the boolean
    result, so the test could never fail. The result must be asserted.
    """
    param = torch.rand((2, 3))

    bucket = ParamBucket(10, param.dtype, param.device)
    param_ = param.clone()

    bucket.add_param(param_)
    # add_param may move the tensor's storage into the bucket buffer;
    # the values themselves must be conserved
    assert torch.allclose(param, param_)


def test_max_size():
    """A parameter larger than the bucket capacity must be rejected with an assert."""
    oversized = torch.rand((20, 30))
    tiny_bucket = ParamBucket(5, oversized.dtype, oversized.device)

    with pytest.raises(AssertionError):
        tiny_bucket.add_param(oversized)


def test_double_check_int():
    """Registering the same tensor twice in one bucket must trip an assert."""
    tensor = torch.rand((5, 6))

    bucket = ParamBucket(300, tensor.dtype, tensor.device)
    bucket.add_param(tensor)

    # A second add of the exact same parameter is refused
    with pytest.raises(AssertionError):
        bucket.add_param(tensor)


def test_type_change():
    """Check that moving a ParamBucket across dtypes updates the buffer and
    keeps the attached param aligned with a reference tensor.

    Bug fixes versus the original:
    - ``Tensor.to(...)`` is out-of-place and returns a new tensor; the original
      discarded the results, so the reference roundtrip was a no-op. The
      results are now reassigned to ``param_``.
    - the final ``torch.allclose`` result was discarded; it is now asserted.
    """
    size = (5, 6)
    param = torch.rand(size, requires_grad=True)
    param_ = param.clone()

    bucket = ParamBucket(30, param.dtype, param.device)
    bucket.add_param(param)

    # Move the bucket to fp16 and back
    bucket.to(dtype=torch.float16, device=param.device)
    assert bucket.buffer.dtype == torch.float16

    bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)
    assert bucket.buffer.dtype == torch.float32

    # Same roundtrip with the reference tensor, reassigning since .to() is
    # out-of-place. Both tensors go through an identical fp16 quantization,
    # so they must still match.
    param_ = param_.to(dtype=torch.float16)
    param_ = param_.to(dtype=torch.float32)

    assert torch.allclose(param, param_)