# Tests for the diagonal-matrix API of dgl.sparse (diag, identity, DiagMatrix).
import sys

import backend as F
import pytest
import torch

from dgl.sparse import diag, DiagMatrix, identity

# TODO(#4818): These tests are skipped on every non-Linux platform (the
# original motivation was Windows); re-enable once the issue is resolved.
if not sys.platform.startswith("linux"):
    pytest.skip(
        "skipping tests on non-Linux platforms", allow_module_level=True
    )


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
@pytest.mark.parametrize("mat_shape", [None, (3, 5), (5, 3)])
def test_diag(val_shape, mat_shape):
    """Check DiagMatrix creation, its attributes, and as_sparse() conversion.

    ``val_shape`` covers scalar diagonal entries ``(3,)`` and vector entries
    ``(3, 2)``; ``mat_shape`` covers the default (square) shape as well as
    explicit wide and tall shapes.
    """
    ctx = F.ctx()
    # creation
    val = torch.randn(val_shape).to(ctx)
    mat = diag(val, mat_shape)

    # val, shape attributes
    assert torch.allclose(mat.val, val)
    if mat_shape is None:
        # Without an explicit shape the matrix is expected to be square with
        # one row/column per diagonal entry.
        mat_shape = (val_shape[0], val_shape[0])
    assert mat.shape == mat_shape

    # nnz
    assert mat.nnz == val.shape[0]
    # dtype
    assert mat.dtype == val.dtype
    # device
    assert mat.device == val.device

    # as_sparse
    sp_mat = mat.as_sparse()
    # shape
    assert tuple(sp_mat.shape) == mat_shape
    # nnz
    assert sp_mat.nnz == mat.nnz
    # dtype
    assert sp_mat.dtype == mat.dtype
    # device
    assert sp_mat.device == mat.device
    # row, col, val: the i-th non-zero must sit at (i, i) and carry val[i].
    diag_index = torch.arange(len(val)).to(mat.device)
    row, col = sp_mat.coo()
    assert torch.equal(row, diag_index)
    assert torch.equal(col, diag_index)
    # Compare against the values the matrix was built from, not against
    # sp_mat.val itself (which would be trivially true).
    assert torch.allclose(sp_mat.val, val)


@pytest.mark.parametrize("shape", [(3, 3), (3, 5), (5, 3)])
@pytest.mark.parametrize("d", [None, 2])
def test_identity(shape, d):
    """Check that identity() builds a DiagMatrix of ones with the given shape.

    ``d`` selects the width of each diagonal entry: ``None`` gives scalar
    entries (values of shape ``(min(shape),)``), an integer gives vector
    entries (values of shape ``(min(shape), d)``).
    """
    # creation
    mat = identity(shape, d)
    # type
    assert isinstance(mat, DiagMatrix)
    # shape
    assert mat.shape == shape
    # nnz: a possibly non-square identity has min(shape) diagonal entries.
    len_val = min(shape)
    assert mat.nnz == len_val
    # val
    val_shape = (len_val,) if d is None else (len_val, d)
    val = torch.ones(val_shape)
    assert torch.allclose(val, mat.val)


def test_print():
    """Smoke-test that printing a DiagMatrix does not raise."""
    ctx = F.ctx()

    # Scalar diagonal entries.
    scalar_vals = torch.tensor([1.0, 1.0, 2.0]).to(ctx)
    print(diag(scalar_vals))

    # Vector-valued diagonal entries: each non-zero holds 2 values.
    vector_vals = torch.randn(3, 2).to(ctx)
    print(diag(vector_vals))