import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm
from ...base import dgl_warning
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx


def _scatter_nd(index, src, n_rows):
    """Similar to PyTorch's scatter nd on first dimension."""
    assert index.shape == src.shape
    dgl_warning("MXNet do not support scatter_add, fallback to numpy.")
    ctx = context(src)
    index = asnumpy(index)
    src = asnumpy(src)
    shp = index.shape
    ndim = src.ndim
    offsets = []
    stride = 1
    for i in reversed(range(1, ndim)):
        di = shp[i]
        offset_i = np.arange(di, dtype=index.dtype)
        offsets.append(
            (stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
        stride *= di
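    # Flatten to 1-D: the destination of element (i, j, ...) is
    # index[i, j, ...] * stride plus its within-row offset, so a single
    # np.add.at call performs the whole scatter-add.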
    new_idx = index * stride + sum(offsets)
    src = src.reshape(-1)
    new_idx = new_idx.reshape(-1)
    rst = np.zeros((stride * n_rows,), dtype=src.dtype)
    np.add.at(rst, new_idx, src)
    rst = rst.reshape(n_rows, *shp[1:])
    rst = copy_to(zerocopy_from_numpy(rst), ctx)
    return rst


def _gather_nd(index, src):
    """Similar to PyTorch's gather nd on first dimension."""
    ctx = context(src)
    shp = index.shape
    ndim = src.ndim
    offsets = []
    stride = 1
    for i in reversed(range(1, ndim)):
        di = shp[i]
        offset_i = nd.arange(di, dtype=index.dtype)
        offsets.append(
            (stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
        stride *= di
    new_idx = index * stride + copy_to(sum(offsets), ctx)
    src = src.reshape(-1)
    new_idx = new_idx.reshape(-1)
    rst = nd.take(src, new_idx).reshape(shp)
    return rst


def _reduce_grad(grad, shape):
    """Reduce gradient on the broadcast dimension
    If there is broadcast in forward pass, gradients need to be reduced on
    broadcast dimension. This function checks the input tensor shape and
    gradient shape and perform the reduction.

    Parameters
    ----------
    grad: Tensor
        Gradient tensor
    shape: tuple
        Shape of input tensor

    Returns
    -------
    Tensor
    """
    grad_shape = grad.shape[1:]
    in_shape = shape[1:]
    if in_shape == grad_shape:
        # no need to reduce
        return grad
    num_to_squeeze = len(grad_shape) - len(in_shape)
    # pad in_shape
    in_shape = (1,) * num_to_squeeze + in_shape
    reduce_idx = np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape))[0]
    reduce_idx += 1  # skip batch dim
    grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
    return grad.reshape(shape)


def _need_reduce_last_dim(ufeat, efeat):
    """Indicates whether to reduce the last dimension on edges
    in the backward pass of spmm,
    if so, use dot instead of mul."""
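    # e.g. ufeat of shape (N, 5, 16) against efeat of shape (E, 5, 1): the
    # forward mul broadcasts over the last dimension, so the edge gradient is a
    # dot product along that dimension.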
    ushp = ufeat.shape
    eshp = efeat.shape
    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1


def _muldiv(op, x):
    return 1. / x if op == 'div' else x


def _addsub(op, x):
    return -x if op == 'sub' else x


class GSpMM(mx.autograd.Function):
    def __init__(self, gidx, op, reduce_op):
        super(GSpMM, self).__init__()
        self.gidx = gidx
        self.op = op
        self.reduce_op = reduce_op

    def forward(self, X, Y):
        out, (argX, argY) = _gspmm(self.gidx, self.op, self.reduce_op, X, Y)
        self.save_for_backward(X, Y, argX, argY)
        return out

    def backward(self, dZ):
        ctx = context(dZ)
        X, Y, argX, argY = self.saved_tensors
        gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
        if op != 'copy_rhs':
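            # Gradient w.r.t. X (lhs/node features): for 'sum' reduce, propagate
            # dZ back along the reversed graph; for other reduce ops (max/min),
            # scatter dZ to the winning positions recorded in argX/argY.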
            g_rev = gidx.reverse()
            if reduce_op == 'sum':
                if op in ['mul', 'div']:
                    dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
                elif op in ['add', 'sub']:
                    dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
                elif op == 'copy_lhs':
                    dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
            else:
                if op in ['mul', 'div']:
                    dX = _scatter_nd(
                        argX,
                        _muldiv(op, _gather_nd(argY, Y.broadcast_to((Y.shape[0], *dZ.shape[1:])))) * dZ,
                        X.shape[0])
                elif op in ['add', 'sub', 'copy_lhs']:
                    dX = _scatter_nd(argX, dZ, X.shape[0])
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = nd.zeros_like(X)
        if op != 'copy_lhs':
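            # Gradient w.r.t. Y (rhs/edge features): for 'sum' reduce it is an
            # SDDMM over the same graph; otherwise dZ is scattered back to the
            # contributing edges recorded in argY.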
            if reduce_op == 'sum':
                if op == 'mul' and _need_reduce_last_dim(X, Y):
                    dY = _gsddmm(gidx, 'dot', X, dZ)
                elif op in ['mul', 'div']:
                    dY = _gsddmm(gidx, 'mul', X, dZ)
                    if op == 'div':
                        dY = -dY / (Y ** 2)
                elif op in ['add', 'sub', 'copy_rhs']:
                    dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
            else:
                if op in ['mul', 'div']:
                    dY = _scatter_nd(
                        argY,
                        _gather_nd(argX, X.broadcast_to((X.shape[0], *dZ.shape[1:]))) * dZ,
                        Y.shape[0])
                    if op == 'div':
                        dY = -dY / (Y ** 2)
                elif op in ['add', 'sub', 'copy_rhs']:
                    dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = nd.zeros_like(Y)
        self.saved_tensors = None
        return dX, dY


def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    func = GSpMM(gidx, op, reduce_op)
    ctx = to_backend_ctx(gidx.ctx)
    if lhs_data is None:
        lhs_data = nd.zeros((1,), ctx=ctx)
    if rhs_data is None:
        rhs_data = nd.zeros((1,), ctx=ctx)
    return func(lhs_data, rhs_data)
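# Usage sketch (illustrative, assuming ``g._graph`` is the graph index of a
# DGLGraph ``g``): gspmm(g._graph, 'mul', 'sum', node_feat, edge_feat)
# multiplies each source node's node_feat by the edge's edge_feat and sums the
# products at every destination node.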


class GSDDMM(mx.autograd.Function):
    def __init__(self, gidx, op, lhs_target, rhs_target):
        super(GSDDMM, self).__init__()
        self.gidx = gidx
        self.op = op
        self.lhs_target = lhs_target
        self.rhs_target = rhs_target

    def forward(self, X, Y):
        out = _gsddmm(self.gidx, self.op, X, Y, self.lhs_target, self.rhs_target)
        self.save_for_backward(X, Y)
        return out

    def backward(self, dZ):
        ctx = context(dZ)
        X, Y = self.saved_tensors
        gidx, op = self.gidx, self.op
        lhs_target, rhs_target = self.lhs_target, self.rhs_target
        if op != 'copy_rhs':
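            # Gradient w.r.t. X: when X lives on nodes ('u'/'v'), pull dZ from
            # the edges back to those nodes with an SpMM; when X lives on edges
            # ('e'), the gradient stays on the edges.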
            if lhs_target in ['u', 'v']:
                _gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
                if op in ['add', 'sub', 'copy_lhs']:
                    dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                else:  # mul, div, dot
                    if rhs_target == lhs_target:
                        dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
                    elif self.rhs_target == 'e':
                        dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
                    else:  # rhs_target = !lhs_target
                        dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
            else:  # lhs_target == 'e'
                if op in ['add', 'sub', 'copy_lhs']:
                    dX = dZ
                else:  # mul, div, dot
                    dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = nd.zeros_like(X)
        if op != 'copy_lhs':
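            # Gradient w.r.t. Y: symmetric to dX above, with sign flips for
            # 'sub' and the reciprocal-square correction for 'div'.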
            if self.rhs_target in ['u', 'v']:
                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                if op in ['add', 'sub', 'copy_rhs']:
                    dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                else:  # mul, div, dot
                    if lhs_target == rhs_target:
                        dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
                    elif self.lhs_target == 'e':
                        dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
                    else:  # rhs_target = !lhs_target
                        dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
                    if op == 'div':
                        dY = -dY / (Y ** 2)
            else:
                if op in ['add', 'sub', 'copy_rhs']:
                    dY = _addsub(op, dZ)
                else:  # mul, div, dot
                    dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
                    if op == 'div':
                        dY = -dY / (Y ** 2)
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = nd.zeros_like(Y)
        self.saved_tensors = None
        return dX, dY


def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
    func = GSDDMM(gidx, op, lhs_target, rhs_target)
    ctx = to_backend_ctx(gidx.ctx)
    if lhs_data is None:
        lhs_data = nd.zeros((1,), ctx=ctx)
    if rhs_data is None:
        rhs_data = nd.zeros((1,), ctx=ctx)
    return func(lhs_data, rhs_data)
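# Usage sketch (illustrative, assuming ``g._graph`` is the graph index of a
# DGLGraph ``g``): gsddmm(g._graph, 'dot', q, k, 'u', 'v') produces one value
# per edge, the dot product of the source node's q with the destination node's k.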