test_elementwise_op_diag.py 2.03 KB
Newer Older
1
import operator
2
import sys
3

4
import backend as F
5
6
import pytest
import torch
7

8
from dgl.sparse import diag, power
9

Mufei Li's avatar
Mufei Li committed
10
# TODO(#4818): Tests were disabled for Windows; the guard below currently
# skips this module on every non-Linux platform.
11
12
13
14
15
# Module-level guard: the whole file is skipped unless the interpreter
# reports a Linux platform. Note that this covers every non-Linux platform,
# not only Windows as the skip reason suggests.
_ON_LINUX = sys.platform.startswith("linux")
if not _ON_LINUX:
    pytest.skip("skipping tests on win", allow_module_level=True)


def all_close_sparse(A, B):
    """Assert that two sparse-matrix-like objects are numerically close.

    Three checks are performed in order: the stacked index tensors returned
    by ``coo()``, the tensors returned by ``values()``, and the ``shape``
    attributes.

    Parameters
    ----------
    A, B
        Objects exposing ``coo()`` (a sequence of tensors that can be
        stacked), ``values()`` (a tensor) and ``shape``. Presumably
        ``dgl.sparse`` matrices -- confirm against callers.

    Raises
    ------
    AssertionError
        If any of the three comparisons fails.
    """
    # Index tensors are stacked so a single allclose call covers all of them.
    assert torch.allclose(torch.stack(A.coo()), torch.stack(B.coo()))
    assert torch.allclose(A.values(), B.values())
    assert A.shape == B.shape


21
22
23
@pytest.mark.parametrize("opname", ["add", "sub", "mul", "truediv"])
def test_diag_op_diag(opname):
    """Check an element-wise binary operator between two diagonal matrices.

    ``opname`` names a function in the ``operator`` module. The operator is
    applied to two matrices built by ``diag`` with shape ``(3, 4)``, and the
    result is compared with the same operator applied directly to their
    ``val`` tensors.
    """
    op = getattr(operator, opname)
    ctx = F.ctx()
    shape = (3, 4)
    D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
    D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)
    result = op(D1, D2)
    # The expected values come from applying the operator to the raw
    # diagonal tensors.
    assert torch.allclose(result.val, op(D1.val, D2.val), rtol=1e-4, atol=1e-4)
    assert result.shape == D1.shape
31
32


33
34
35
@pytest.mark.parametrize(
    "v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_diag_op_scalar(v_scalar):
    """Check operators between a diagonal matrix and a scalar.

    ``v_scalar`` is a Python number or a zero-dimensional tensor.
    Multiplication (both operand orders), division and exponentiation
    (``**`` and ``power``) are compared against the same operation applied
    to the matrix's ``val`` tensor. Addition and subtraction with a scalar,
    in either operand order, are expected to raise ``TypeError``.
    """
    ctx = F.ctx()
    shape = (3, 4)
    D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)

    # D * v
    D2 = D1 * v_scalar
    assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # v * D
    D2 = v_scalar * D1
    assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # D / v
    D2 = D1 / v_scalar
    assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # D ^ v
    # D1 is rebuilt without an explicit shape for the power checks.
    D1 = diag(torch.arange(1, 4).to(ctx))
    D2 = D1**v_scalar
    assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # pow(D, v)
    D2 = power(D1, v_scalar)
    assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # Adding or subtracting a scalar must be rejected in both operand orders.
    with pytest.raises(TypeError):
        D1 + v_scalar
    with pytest.raises(TypeError):
        v_scalar + D1

    with pytest.raises(TypeError):
        D1 - v_scalar
    with pytest.raises(TypeError):
        v_scalar - D1