# test_elementwise_op.py
import operator
2
3
4
import sys

import backend as F
5
6

import dgl.sparse as dglsp
7
8
9
10
11
import pytest
import torch


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_coo(val_shape, opname):
    """Check element-wise add/sub between two COO sparse matrices.

    Both the Python operator (``+`` / ``-`` via ``operator.add`` /
    ``operator.sub``) and the functional form (``dglsp.add`` / ``dglsp.sub``)
    are compared against the same operation applied to the dense
    equivalents. Mixing a sparse matrix with a Python scalar, in either
    operand order, is expected to raise ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    # A: 3 non-zeros at (1, 0), (0, 3), (2, 2); no explicit shape is passed.
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    # B: a partially overlapping pattern, (1, 0) and (0, 2), forced to A's
    # shape so the two operands are compatible.
    row = torch.tensor([1, 0]).to(ctx)
    col = torch.tensor([0, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    B = dglsp.from_coo(row, col, val, shape=A.shape)

    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Sparse/scalar add and sub are expected to be rejected.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csr(val_shape, opname):
    """Check element-wise add/sub between two CSR sparse matrices.

    Both the Python operator (``operator.add`` / ``operator.sub``) and the
    functional form (``dglsp.add`` / ``dglsp.sub``) are compared against the
    same operation applied to the dense equivalents. Mixing a sparse matrix
    with a Python scalar, in either operand order, is expected to raise
    ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    # A: 3 rows with one entry each (column indices 3, 0, 2); no explicit
    # shape is passed.
    indptr = torch.tensor([0, 1, 2, 3]).to(ctx)
    indices = torch.tensor([3, 0, 2]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = dglsp.from_csr(indptr, indices, val)

    # B: the last row is empty (indptr repeats 2); forced to A's shape so
    # the two operands are compatible.
    indptr = torch.tensor([0, 1, 2, 2]).to(ctx)
    indices = torch.tensor([2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = dglsp.from_csr(indptr, indices, val, shape=A.shape)

    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Sparse/scalar add and sub are expected to be rejected.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csc(val_shape, opname):
    """Check element-wise add/sub between two CSC sparse matrices.

    Both the Python operator (``operator.add`` / ``operator.sub``) and the
    functional form (``dglsp.add`` / ``dglsp.sub``) are compared against the
    same operation applied to the dense equivalents. Mixing a sparse matrix
    with a Python scalar, in either operand order, is expected to raise
    ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    # A: 4 columns, column 1 empty (indptr repeats 1); row indices 1, 2, 0.
    # No explicit shape is passed.
    indptr = torch.tensor([0, 1, 1, 2, 3]).to(ctx)
    indices = torch.tensor([1, 2, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    A = dglsp.from_csc(indptr, indices, val)

    # B: columns 1 and 3 empty; forced to A's shape so the two operands are
    # compatible.
    indptr = torch.tensor([0, 1, 1, 2, 2]).to(ctx)
    indices = torch.tensor([1, 0]).to(ctx)
    val = torch.randn(indices.shape + val_shape).to(ctx)
    B = dglsp.from_csc(indptr, indices, val, shape=A.shape)

    C1 = op(A, B).to_dense()
    C2 = func(A, B).to_dense()
    dense_C = op(A.to_dense(), B.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Sparse/scalar add and sub are expected to be rejected.
    with pytest.raises(TypeError):
        op(A, 2)
    with pytest.raises(TypeError):
        op(2, A)


@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_diag(val_shape, opname):
    """Check element-wise add/sub between two diagonal matrices.

    Two non-square (3 x 4) diagonal matrices are combined with both the
    Python operator and the ``dglsp`` functional form, and the results are
    compared against the dense computation. Mixing a diagonal matrix with a
    Python scalar, in either operand order, is expected to raise
    ``TypeError``.
    """
    op = getattr(operator, opname)
    func = getattr(dglsp, opname)

    ctx = F.ctx()
    shape = (3, 4)
    # Leading dimension is shape[0] == 3, the diagonal length of a 3 x 4
    # matrix; the trailing part is the per-entry value shape under test.
    # A separate name keeps the parametrized ``val_shape`` unmodified.
    diag_val_shape = (shape[0],) + val_shape

    D1 = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)
    D2 = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    C1 = op(D1, D2).to_dense()
    C2 = func(D1, D2).to_dense()
    dense_C = op(D1.to_dense(), D2.to_dense())

    assert torch.allclose(dense_C, C1)
    assert torch.allclose(dense_C, C2)

    # Diagonal/scalar add and sub are expected to be rejected.
    with pytest.raises(TypeError):
        op(D1, 2)
    with pytest.raises(TypeError):
        op(2, D1)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_sparse_diag(val_shape):
    """Check addition between a COO sparse matrix and a diagonal matrix.

    Addition is exercised in both operand orders, via ``+`` and via
    ``dglsp.add``. Every result must match the dense sum, because addition
    is commutative.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    # Diagonal length shape[0] == 3 for the 3 x 4 matrix; a separate name
    # keeps the parametrized ``val_shape`` unmodified.
    diag_val_shape = (shape[0],) + val_shape
    D = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    sum1 = (A + D).to_dense()
    sum2 = (D + A).to_dense()
    sum3 = dglsp.add(A, D).to_dense()
    sum4 = dglsp.add(D, A).to_dense()
    dense_sum = A.to_dense() + D.to_dense()

    assert torch.allclose(dense_sum, sum1)
    assert torch.allclose(dense_sum, sum2)
    assert torch.allclose(dense_sum, sum3)
    assert torch.allclose(dense_sum, sum4)


@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_sub_sparse_diag(val_shape):
    """Check subtraction between a COO sparse matrix and a diagonal matrix.

    Subtraction is exercised in both operand orders, via ``-`` and via
    ``dglsp.sub``. Swapping the operands must negate the result, so the
    reversed-order results are compared against ``-(A - D)``.
    """
    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape + val_shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    # Diagonal length shape[0] == 3 for the 3 x 4 matrix; a separate name
    # keeps the parametrized ``val_shape`` unmodified.
    diag_val_shape = (shape[0],) + val_shape
    D = dglsp.diag(torch.randn(diag_val_shape).to(ctx), shape=shape)

    diff1 = (A - D).to_dense()
    diff2 = (D - A).to_dense()
    diff3 = dglsp.sub(A, D).to_dense()
    diff4 = dglsp.sub(D, A).to_dense()
    dense_diff = A.to_dense() - D.to_dense()

    assert torch.allclose(dense_diff, diff1)
    # D - A == -(A - D)
    assert torch.allclose(dense_diff, -diff2)
    assert torch.allclose(dense_diff, diff3)
    assert torch.allclose(dense_diff, -diff4)


@pytest.mark.parametrize("op", ["mul", "truediv", "pow"])
def test_error_op_sparse_diag(op):
    """Check that unsupported binary ops between sparse and diagonal raise.

    ``operator.mul``, ``operator.truediv`` and ``operator.pow`` applied to a
    COO sparse matrix and a diagonal matrix, in either operand order, are
    expected to raise ``TypeError``.
    """
    binary_op = getattr(operator, op)

    ctx = F.ctx()
    row = torch.tensor([1, 0, 2]).to(ctx)
    col = torch.tensor([0, 3, 2]).to(ctx)
    val = torch.randn(row.shape).to(ctx)
    A = dglsp.from_coo(row, col, val)

    shape = (3, 4)
    # The diagonal length is tied to the matrix shape (shape[0]), not to
    # A's number of non-zeros, which only happened to be equal before.
    D = dglsp.diag(torch.randn(shape[0]).to(ctx), shape=shape)

    with pytest.raises(TypeError):
        binary_op(A, D)
    with pytest.raises(TypeError):
        binary_op(D, A)