test_parameter.py 2.29 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from helpers.exception import assert_fail_with
from nanotron.models.base import DTypeInvariantTensor, init_on_device_and_dtype
from nanotron.parallel.parameters import NanotronParameter
from torch import nn


def test_nanotron_parameter_does_not_override_some_parameter_variable():
    """Check that a plain ``nn.Parameter`` does not expose the NanotronParameter metadata attribute."""
    metadata_attr = NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
    plain_param = nn.Parameter(torch.empty(3))
    # The name nanotron uses for its metadata must not already exist on a vanilla parameter.
    assert not hasattr(plain_param, metadata_attr)


def test_uncastable_tensor():
    """Check that a DTypeInvariantTensor rejects dtype changes but allows value and device changes."""
    # Construction: the wrapper is both a regular tensor and a DTypeInvariantTensor.
    invariant = DTypeInvariantTensor(torch.randn(3, 3))
    for expected_type in (torch.Tensor, DTypeInvariantTensor):
        assert isinstance(invariant, expected_type)

    # Every dtype-changing call must raise a RuntimeError carrying the exact message below.
    forbidden_casts = (
        ("Cannot convert the type of an DTypeInvariantTensor to float", lambda t: t.float()),
        ("Cannot convert the type of an DTypeInvariantTensor to half", lambda t: t.half()),
        ("Cannot change the type of an DTypeInvariantTensor", lambda t: t.to(torch.float32)),
        ("Cannot change the type of an DTypeInvariantTensor", lambda t: t.to(dtype=torch.float32)),
    )
    for expected_message, cast in forbidden_casts:
        with assert_fail_with(RuntimeError, error_msg=expected_message):
            cast(invariant)

    # In-place value updates are still allowed.
    invariant[0, 0] = 1
    assert invariant[0, 0] == 1

    # Moving to another device is still allowed (needs a CUDA device to run).
    on_gpu = invariant.to("cuda")
    assert on_gpu.device.type == "cuda"


def test_register_buffer_does_not_update_uncastable_tensor():
    """Check that a DTypeInvariantTensor buffer registered under ``init_on_device_and_dtype``
    ends up on the context's device while keeping its original dtype.
    """
    source_device, source_dtype = torch.device("cuda"), torch.float32
    target_device, target_dtype = torch.device("cpu"), torch.bfloat16

    with init_on_device_and_dtype(device=target_device, dtype=target_dtype):
        holder = torch.nn.Module()
        # A DTypeInvariantTensor can be registered as a buffer.
        invariant = DTypeInvariantTensor(torch.randn(3, 4, dtype=source_dtype, device=source_device))
        holder.register_buffer("buffer", invariant)

        # The registered buffer remains writable.
        holder.buffer[0, 0] = 1
        assert holder.buffer[0, 0] == 1

        registered = holder.buffer
        # The device follows the surrounding context...
        assert registered.device.type == target_device.type
        # ...whereas the dtype is the one the tensor was created with.
        assert registered.dtype is source_dtype