test_elementwise_op_diag.py 1.73 KB
Newer Older
1
import operator
2
import sys
3

4
import backend as F
5
6
import pytest
import torch
7

8
from dgl.sparse import diag, power
9

Mufei Li's avatar
Mufei Li committed
10
# TODO(#4818): Skipping tests on every non-Linux platform (not only win).
11
12
13
14
15
# Skip the whole module unless running on Linux. Note that the check excludes
# every non-Linux platform, although the skip message only mentions "win".
if not sys.platform.startswith("linux"):
    pytest.skip("skipping tests on win", allow_module_level=True)


def all_close_sparse(A, B):
    """Assert that two sparse matrices are (approximately) identical.

    Both arguments must expose ``coo()`` (returning the index tensors),
    ``values()`` and ``shape``. The helper asserts, in order, that the
    shapes match, that the stacked COO indices match and that the values
    match within ``torch.allclose``'s default tolerances.

    Raises
    ------
    AssertionError
        If any of the comparisons fails.
    """
    assert A.shape == B.shape

    indices_a = torch.stack(A.coo())
    indices_b = torch.stack(B.coo())
    # Compare tensor shapes explicitly so that operands with a different
    # number of non-zero entries fail the assertion instead of being
    # broadcast against each other (or raising) inside torch.allclose.
    assert indices_a.shape == indices_b.shape
    assert torch.allclose(indices_a, indices_b)

    values_a = A.values()
    values_b = B.values()
    assert values_a.shape == values_b.shape
    assert torch.allclose(values_a, values_b)


21
@pytest.mark.parametrize("op", [operator.sub, operator.mul, operator.truediv])
def test_diag_op_diag(op):
    """Check a binary operator applied to two diagonal matrices.

    Two diagonal matrices of the same ``(3, 4)`` shape are combined with
    ``op``. The result's ``val`` must match ``op`` applied directly to the
    operands' ``val`` tensors, and the result must keep the operands' shape.
    """
    ctx = F.ctx()
    shape = (3, 4)
    D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)
    D2 = diag(torch.arange(10, 13).to(ctx), shape=shape)

    result = op(D1, D2)

    expected_val = op(D1.val, D2.val)
    assert torch.allclose(result.val, expected_val, rtol=1e-4, atol=1e-4)
    assert result.shape == D1.shape
30
31
32
33


@pytest.mark.parametrize("v_scalar", [2, 2.5])
def test_diag_op_scalar(v_scalar):
    """Check elementwise operations between a diagonal matrix and a scalar.

    Covers ``D * v``, ``v * D``, ``D / v``, ``D ** v`` and ``power(D, v)``.
    For each operation the result's ``val`` must match the same operation
    applied to the operand's ``val`` tensor, and the shape must be unchanged.
    """
    ctx = F.ctx()
    shape = (3, 4)
    D1 = diag(torch.arange(1, 4).to(ctx), shape=shape)

    # D * v
    D2 = D1 * v_scalar
    assert torch.allclose(D1.val * v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # v * D
    D2 = v_scalar * D1
    assert torch.allclose(v_scalar * D1.val, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # D / v
    D2 = D1 / v_scalar
    assert torch.allclose(D1.val / v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # D ^ v
    # The operand is rebuilt without an explicit shape for the power checks.
    # NOTE(review): presumably power requires a square matrix -- confirm.
    D1 = diag(torch.arange(1, 4).to(ctx))
    D2 = D1**v_scalar
    assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape

    # pow(D, v)
    D2 = power(D1, v_scalar)
    assert torch.allclose(D1.val**v_scalar, D2.val, rtol=1e-4, atol=1e-4)
    assert D1.shape == D2.shape