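"""Unit tests for DiagMatrix in dgl.sparse.

Covers creation via diag() and identity(), attribute access, conversion to a
SparseMatrix, printing, device and dtype conversion, and transposition.
"""
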
import sys
import unittest

import backend as F

import pytest
import torch

from dgl.sparse import diag, DiagMatrix, identity


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag(val_shape, mat_shape):
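    """Create a DiagMatrix with diag() and check its attributes and its
    conversion to a SparseMatrix in COO format."""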
    ctx = F.ctx()
    # creation
    val = torch.randn(val_shape).to(ctx)
    mat = diag(val, mat_shape)

    # val, shape attributes
    assert torch.allclose(mat.val, val)
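    # With mat_shape=None, diag() is expected to infer a square shape from
    # the number of diagonal values.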
    if mat_shape is None:
        mat_shape = (val_shape[0], val_shape[0])
    assert mat.shape == mat_shape

    val = torch.randn(val_shape).to(ctx)

    # nnz
    assert mat.nnz == val.shape[0]
    # dtype
    assert mat.dtype == val.dtype
    # device
    assert mat.device == val.device

    # conversion to SparseMatrix via to_sparse()
    sp_mat = mat.to_sparse()
    # shape
    assert tuple(sp_mat.shape) == mat_shape
    # nnz
    assert sp_mat.nnz == mat.nnz
    # dtype
    assert sp_mat.dtype == mat.dtype
    # device
    assert sp_mat.device == mat.device
    # row, col, val
    edge_index = torch.arange(len(val)).to(mat.device)
    row, col = sp_mat.coo()
    val = sp_mat.val
    assert torch.allclose(row, edge_index)
    assert torch.allclose(col, edge_index)
    assert torch.allclose(val, mat.val)


@pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)])
@pytest.mark.parametrize("d", [None, 2])
def test_identity(shape, d):
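    """Create an identity DiagMatrix with identity() and check its type,
    shape, and that every diagonal value is one."""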
    ctx = F.ctx()
    # creation
    mat = identity(shape, d)
    # type
    assert isinstance(mat, DiagMatrix)
    # shape
    assert mat.shape == shape
    # val
    len_val = min(shape)
    if d is None:
        val_shape = len_val
    else:
        val_shape = (len_val, d)
    val = torch.ones(val_shape)
    assert torch.allclose(val, mat.val)


def test_print():
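    """Smoke test: printing a DiagMatrix with scalar or vector diagonal
    values should not raise."""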
    ctx = F.ctx()

    # basic
    val = torch.tensor([1.0, 1.0, 2.0]).to(ctx)
    A = diag(val)
    print(A)

    # vector-shaped nonzero values
    val = torch.randn(3, 2).to(ctx)
    A = diag(val)
    print(A)


@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Device conversions don't need to be tested on CPU.",
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_to_device(device):
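    """Check device conversion via .to(device=...) and via the .cpu() /
    .cuda() shorthand methods."""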
    val = torch.randn(3)
    mat_shape = (3, 4)
    mat = diag(val, mat_shape)

    target_val = mat.val.to(device)
    mat2 = mat.to(device=device)
    assert mat2.shape == mat.shape
    assert torch.allclose(mat2.val, target_val)

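    # The same conversion should also be reachable through the .cpu() /
    # .cuda() shorthand methods, invoked here via getattr.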
    mat2 = getattr(mat, device)()
    assert mat2.shape == mat.shape
    assert torch.allclose(mat2.val, target_val)


@pytest.mark.parametrize(
    "dtype", [torch.float, torch.double, torch.int, torch.long]
)
def test_to_dtype(dtype):
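    """Check dtype conversion via .to(dtype=...) and via the per-dtype
    shorthand methods (.float(), .double(), .int(), .long())."""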
    val = torch.randn(3)
    mat_shape = (3, 4)
    mat = diag(val, mat_shape)

    target_val = mat.val.to(dtype=dtype)
    mat2 = mat.to(dtype=dtype)
    assert mat2.shape == mat.shape
    assert torch.allclose(mat2.val, target_val)

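    # The same conversion should also be reachable through the per-dtype
    # shorthand methods, invoked here via getattr.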
    func_name = {
        torch.float: "float",
        torch.double: "double",
        torch.int: "int",
        torch.long: "long",
    }
    mat2 = getattr(mat, func_name[dtype])()
    assert mat2.shape == mat.shape
    assert torch.allclose(mat2.val, target_val)


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag_matrix_transpose(val_shape, mat_shape):
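    """Check that transpose() swaps the matrix shape while leaving the
    diagonal values unchanged."""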
    ctx = F.ctx()
    val = torch.randn(val_shape).to(ctx)
    mat = diag(val, mat_shape).transpose()

    assert torch.allclose(mat.val, val)
    if mat_shape is None:
        mat_shape = (val_shape[0], val_shape[0])
    assert mat.shape == mat_shape[::-1]