"""Tests for element-wise operations between ``dgl.sparse`` matrices.

Covers addition of COO/CSR/CSC sparse matrices and diagonal matrices, and
checks that ``*``, ``/`` and ``**`` between a sparse and a diagonal matrix
raise ``TypeError``.
"""
import operator
import sys

import backend as F
import pytest
import torch

from dgl.sparse import add, diag, from_coo, from_csc, from_csr

# TODO(#4818): Skipping tests on win.
# NOTE: the condition skips every non-Linux platform (e.g. macOS too), not
# only Windows, so the skip message says so.
if not sys.platform.startswith("linux"):
    pytest.skip(
        "skipping tests on non-linux platforms", allow_module_level=True
    )


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_coo(val_shape):
    """Adding two COO sparse matrices matches the dense sum.

    Both the ``+`` operator and ``dgl.sparse.add`` are checked, with
    per-entry values of shape ``val_shape`` (scalar entries for ``()``,
    length-2 vectors for ``(2,)``). Adding a Python scalar to a sparse
    matrix, on either side, must raise ``TypeError``.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = from_coo(row, col, val)

    # B shares the entry (1, 0) with A and adds a new one at (0, 2). It is
    # built with A's shape so the two operands are addable.
    row = torch.tensor([1, 0]).to(ctx)
    col = torch.tensor([0, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    B = from_coo(row, col, val, shape=A.shape)

    sum1 = (A + B).dense()
    sum2 = add(A, B).dense()
    dense_sum = A.dense() + B.dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)

    # Sparse + scalar is not supported in either operand order.
    with pytest.raises(TypeError):
        A + 2
    with pytest.raises(TypeError):
        2 + A

@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_csr(val_shape):
    """Adding two CSR sparse matrices matches the dense sum.

    Both the ``+`` operator and ``dgl.sparse.add`` are checked, with
    per-entry values of shape ``val_shape``. Adding a Python scalar to a
    sparse matrix, on either side, must raise ``TypeError``.
    """
    ctx = F.ctx()
    # Rows 0, 1, 2 each hold one entry, in columns 3, 0 and 2 respectively.
    indptr = torch.tensor([0, 1, 2, 3]).to(ctx)
    indices = torch.tensor([3, 0, 2]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = from_csr(indptr, indices, val)

    # B: row 0 -> column 2, row 1 -> column 0, row 2 empty. It is built with
    # A's shape so the two operands are addable.
    indptr = torch.tensor([0, 1, 2, 2]).to(ctx)
    indices = torch.tensor([2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = from_csr(indptr, indices, val, shape=A.shape)

    sum1 = (A + B).dense()
    sum2 = add(A, B).dense()
    dense_sum = A.dense() + B.dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)

    # Sparse + scalar is not supported in either operand order.
    with pytest.raises(TypeError):
        A + 2
    with pytest.raises(TypeError):
        2 + A

@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_csc(val_shape):
    """Adding two CSC sparse matrices matches the dense sum.

    Both the ``+`` operator and ``dgl.sparse.add`` are checked, with
    per-entry values of shape ``val_shape``. Adding a Python scalar to a
    sparse matrix, on either side, must raise ``TypeError``.
    """
    ctx = F.ctx()
    # Column 0 -> row 1, column 1 empty, column 2 -> row 2, column 3 -> row 0.
    indptr = torch.tensor([0, 1, 1, 2, 3]).to(ctx)
    indices = torch.tensor([1, 2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = from_csc(indptr, indices, val)

    # B: column 0 -> row 1, column 2 -> row 0, columns 1 and 3 empty. It is
    # built with A's shape so the two operands are addable.
    indptr = torch.tensor([0, 1, 1, 2, 2]).to(ctx)
    indices = torch.tensor([1, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = from_csc(indptr, indices, val, shape=A.shape)

    sum1 = (A + B).dense()
    sum2 = add(A, B).dense()
    dense_sum = A.dense() + B.dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)

    # Sparse + scalar is not supported in either operand order.
    with pytest.raises(TypeError):
        A + 2
    with pytest.raises(TypeError):
        2 + A

@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_diag(val_shape):
    """Adding two diagonal matrices matches the dense sum.

    Both the ``+`` operator and ``dgl.sparse.add`` are checked on a
    non-square (3, 4) shape, with per-entry values of shape ``val_shape``.
    """
    ctx = F.ctx()
    shape = (3, 4)
    # One diagonal entry per row (shape[0] entries), each of shape val_shape.
    # Kept under its own name instead of overwriting the parametrized
    # argument.
    diag_val_shape = (shape[0],) + val_shape
    D1 = diag(torch.randn(diag_val_shape).to(ctx), shape=shape)
    D2 = diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    sum1 = (D1 + D2).dense()
    sum2 = add(D1, D2).dense()
    dense_sum = D1.dense() + D2.dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_sparse_diag(val_shape):
    """Adding a COO sparse matrix and a diagonal matrix matches the dense sum.

    Both operand orders are checked, through the ``+`` operator and through
    ``dgl.sparse.add``, with per-entry values of shape ``val_shape``.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = from_coo(row, col, val)

    # NOTE(review): assumes from_coo infers shape (3, 4) from max row/col
    # index, matching D's explicit shape — confirm against dgl.sparse.
    shape = (3, 4)
    # One diagonal entry per row; kept separate from the parametrized
    # val_shape instead of overwriting it.
    diag_val_shape = (shape[0],) + val_shape
    D = diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    sum1 = (A + D).dense()
    sum2 = (D + A).dense()
    sum3 = add(A, D).dense()
    sum4 = add(D, A).dense()
    dense_sum = A.dense() + D.dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)
    assert torch.allclose(dense_sum, sum3)
    assert torch.allclose(dense_sum, sum4)


@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
def test_error_op_sparse_diag(op):
    """Unsupported ops between a sparse and a diagonal matrix raise TypeError.

    ``op`` names a binary function in the ``operator`` module. It is applied
    in both operand orders (sparse, diag) and (diag, sparse).
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape).to(ctx)
    A = from_coo(row, col, val)

    shape = (3, 4)
    # Size the diagonal from the diagonal matrix's own shape (one entry per
    # row), as the other diag tests do. The old code used A's nnz, which only
    # matched by coincidence.
    D = diag(torch.randn(shape[0]).to(ctx), shape=shape)

    binary_op = getattr(operator, op)
    with pytest.raises(TypeError):
        binary_op(A, D)
    with pytest.raises(TypeError):
        binary_op(D, A)