"""Tests for element-wise operators on ``dgl.sparse`` matrices."""
import operator
import sys

import backend as F

import dgl.sparse as dglsp
import pytest
import torch

# TODO(#4818): Skipping tests on win.
# Note: the condition skips every non-Linux platform, not only Windows.
if not sys.platform.startswith("linux"):
    pytest.skip("skipping tests on win", allow_module_level=True)


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_coo(val_shape, opname):
    """Check sparse +/- sparse on matrices built with ``dglsp.from_coo``.

    ``val_shape`` is the per-entry value shape: ``()`` stores one scalar per
    nonzero, ``(2,)`` a length-2 vector per nonzero. Both the Python operator
    (``operator.add``/``operator.sub``) and the matching ``dglsp`` function
    must equal the same operator applied to the dense forms, and combining a
    sparse matrix with a Python scalar must raise ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    # A has entries at (1, 0), (0, 3) and (2, 2).
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    # B overlaps A at (1, 0) and adds a new entry at (0, 2), so both shared and
    # disjoint positions are exercised. A's shape is passed so the operands match.
    row = torch.tensor([1, 0]).to(ctx)
    col = torch.tensor([0, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    B = dglsp.from_coo(row, col, val, shape=A.shape)

    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Sparse matrix combined with a Python scalar is rejected in either order.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)

@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csr(val_shape, opname):
    """Check sparse +/- sparse on matrices built with ``dglsp.from_csr``.

    ``val_shape`` is the per-entry value shape (``()`` for scalars, ``(2,)``
    for length-2 vectors). The Python operator and the matching ``dglsp``
    function must both equal the same operator on the dense forms, and
    combining a sparse matrix with a Python scalar must raise ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    # Row pointers / column indices for entries (0, 3), (1, 0), (2, 2).
    indptr = torch.tensor([0, 1, 2, 3]).to(ctx)
    indices = torch.tensor([3, 0, 2]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = dglsp.from_csr(indptr, indices, val)

    # Entries (0, 2), (1, 0); the last row is empty. A's shape is passed so
    # the two operands have the same dimensions.
    indptr = torch.tensor([0, 1, 2, 2]).to(ctx)
    indices = torch.tensor([2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = dglsp.from_csr(indptr, indices, val, shape=A.shape)

    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Sparse matrix combined with a Python scalar is rejected in either order.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)

@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csc(val_shape, opname):
    """Check sparse +/- sparse on matrices built with ``dglsp.from_csc``.

    ``val_shape`` is the per-entry value shape (``()`` for scalars, ``(2,)``
    for length-2 vectors). The Python operator and the matching ``dglsp``
    function must both equal the same operator on the dense forms, and
    combining a sparse matrix with a Python scalar must raise ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    # Column pointers / row indices for entries (1, 0), (2, 2), (0, 3);
    # column 1 is empty.
    indptr = torch.tensor([0, 1, 1, 2, 3]).to(ctx)
    indices = torch.tensor([1, 2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = dglsp.from_csc(indptr, indices, val)

    # Entries (1, 0), (0, 2). A's shape is passed so the operands match.
    indptr = torch.tensor([0, 1, 1, 2, 2]).to(ctx)
    indices = torch.tensor([1, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = dglsp.from_csc(indptr, indices, val, shape=A.shape)

    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Sparse matrix combined with a Python scalar is rejected in either order.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)

@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_diag(val_shape, opname):
    """Check diag +/- diag for two ``dglsp.diag`` matrices of shape (3, 4).

    ``val_shape`` is the shape of each diagonal value. The Python operator
    and the matching ``dglsp`` function must both equal the same operator on
    the dense forms, and combining a diagonal matrix with a Python scalar
    must raise ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    shape = (3, 4)
    # shape[0] diagonal values, each of shape ``val_shape``. Kept in a separate
    # name so the ``val_shape`` parameter keeps its documented meaning.
    diag_val_shape = (shape[0],) + val_shape

    D1 = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)
    D2 = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    C1 = op(D1, D2).to_dense()
    C2 = func(D1, D2).to_dense()
    dense_C = op(D1.to_dense(), D2.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Diagonal matrix combined with a Python scalar is rejected in either order.
    with pytest.raises(TypeError):
        op(D1, 2)
    with pytest.raises(TypeError):
        op(2, D1)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_sparse_diag(val_shape):
    """Check addition between a COO sparse matrix and a diagonal matrix.

    Both operand orders, via ``+`` and via ``dglsp.add``, must equal the
    dense sum. ``val_shape`` is the per-entry value shape of both operands.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    # shape[0] diagonal values, each of shape ``val_shape``.
    diag_val_shape = (shape[0],) + val_shape
    D = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    sum1 = (A + D).to_dense()
    sum2 = (D + A).to_dense()
    sum3 = dglsp.add(A, D).to_dense()
    sum4 = dglsp.add(D, A).to_dense()
    dense_sum = A.to_dense() + D.to_dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)
    assert torch.allclose(dense_sum, sum3)
    assert torch.allclose(dense_sum, sum4)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_sub_sparse_diag(val_shape):
    """Check subtraction between a COO sparse matrix and a diagonal matrix.

    ``A - D`` (via ``-`` and ``dglsp.sub``) must equal the dense difference;
    ``D - A`` must equal its negation. ``val_shape`` is the per-entry value
    shape of both operands.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    # shape[0] diagonal values, each of shape ``val_shape``.
    diag_val_shape = (shape[0],) + val_shape
    D = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    diff1 = (A - D).to_dense()
    diff2 = (D - A).to_dense()
    diff3 = dglsp.sub(A, D).to_dense()
    diff4 = dglsp.sub(D, A).to_dense()
    dense_diff = A.to_dense() - D.to_dense()

    assert torch.allclose(dense_diff, diff1)
    # D - A == -(A - D)
    assert torch.allclose(dense_diff, -diff2)
    assert torch.allclose(dense_diff, diff3)
    assert torch.allclose(dense_diff, -diff4)


@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
def test_error_op_sparse_diag(op):
    """Check that these operators between sparse and diagonal matrices raise.

    ``op`` names a function in the :mod:`operator` module. Applying it to a
    COO sparse matrix and a diagonal matrix must raise ``TypeError`` in
    either operand order.
    """
    op_func = getattr(operator, op)

    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    # Size the diagonal from D's own shape (as the other diag tests do) rather
    # than from A's number of entries, which only coincidentally equals 3 here.
    D = dglsp.diag(torch.randn(shape[0]).to(ctx), shape=shape)

    with pytest.raises(TypeError):
        op_func(A, D)
    with pytest.raises(TypeError):
        op_func(D, A)