"tests/python/vscode:/vscode.git/clone" did not exist on "c09d466082b4f7e5bc2f7509cb90bd200d9ead50"
sparse.py 7.68 KB
Newer Older
1
2
3
4
5
import tensorflow as tf
import numpy as np
from .tensor import tensor, copy_to, context
from ...sparse import _gspmm, _gsddmm


def _scatter_nd(index, src, n_rows):
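    """Scatter src into a tensor with n_rows rows using a per-position row index.

    index and src have the same shape; the value src[i, j, ...] is added at
    position (index[i, j, ...], j, ...) of a zero tensor with n_rows rows.
    Used when back-propagating through max/min reducers.
    """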
    assert index.shape == src.shape
    shp = index.shape
    ctx = context(src)
    ndim = index.ndim
    offsets = []
    stride = 1
    for i in reversed(range(1, ndim)):
        di = shp[i]
        offset_i = tf.range(di, dtype=index.dtype)
        offsets.append(
            tf.reshape((stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)))
        stride *= di
    new_idx = index * stride + copy_to(sum(offsets), ctx)
    src = tf.reshape(src, (-1,))
    new_idx = tf.reshape(new_idx, (-1, 1))
    rst = tf.reshape(tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:]))
    return rst


def _gather_nd(index, src):
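    """Gather rows of src according to a per-position row index.

    The result at (i, j, ...) is src[index[i, j, ...], j, ...]. This is the
    counterpart of _scatter_nd and is likewise used when back-propagating
    through max/min reducers.
    """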
    shp = index.shape
    ctx = context(src)
    ndim = index.ndim
    offsets = []
    stride = 1
    for i in reversed(range(1, ndim)):
        di = shp[i]
        offset_i = tf.range(di, dtype=index.dtype)
        offsets.append(
            tf.reshape((stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)))
        stride *= di
    new_idx = index * stride + copy_to(sum(offsets), ctx)
    src = tf.reshape(src, (-1,))
    new_idx = tf.reshape(new_idx, (-1,))
    rst = tf.reshape(tf.gather(src, new_idx), shp)
    return rst

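# A minimal sketch (not part of the original module) showing how the two
# helpers above pair up on toy data; the concrete values are illustrative
# assumptions only. _gather_nd picks one row per column position, and
# _scatter_nd adds those values back into the rows named by the same index.
def _example_gather_scatter():
    idx = tf.constant([[1, 0], [0, 1]], dtype=tf.int64)
    src = tf.constant([[10., 20.], [30., 40.]])
    picked = _gather_nd(idx, src)            # [[30., 20.], [10., 40.]]
    restored = _scatter_nd(idx, picked, 2)   # equals src: idx is a per-column permutation
    return picked, restored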

def _reduce_grad(grad, shape):
    """Reduce gradient on the broadcast dimension
    If there is broadcast in forward pass, gradients need to be reduced on
    broadcast dimension. This function checks the input tensor shape and
    gradient shape and perform the reduction.
    Parameters
    ----------
    grad: Tensor
        Gradient tensor
    shape: tuple
        Shape of input tensor
    Returns
    -------
    Tensor
    """
    grad_shape = grad.shape[1:]
    in_shape = shape[1:]
    if in_shape == grad_shape:
        # no need to reduce
        return grad
    num_to_squeeze = len(grad_shape) - len(in_shape)
    # pad inshape
    in_shape = (1,) * num_to_squeeze + in_shape
    reduce_idx = np.asarray(np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape)))
    reduce_idx += 1  # skip batch dim
    reduce_idx_tensor = tf.constant(tuple(
        reduce_idx.flatten().tolist()))
    grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
    return tf.reshape(grad, shape)

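# Hedged sketch of _reduce_grad on assumed shapes: a forward input of shape
# (3, 1, 5) broadcast against a (3, 4, 5) operand receives a (3, 4, 5)
# gradient, which must be summed over the broadcast axis before reshaping.
def _example_reduce_grad():
    grad = tf.ones((3, 4, 5))
    return _reduce_grad(grad, (3, 1, 5))  # shape (3, 1, 5), every entry 4.0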

def _need_reduce_last_dim(ufeat, efeat):
    """Indicates whether to reduce the last dimension on edges
    in the backward pass of spmm,
    if so, use dot instead of mul."""
    ushp = ufeat.shape
    eshp = efeat.shape
    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1


def _muldiv(op, x):
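    """Return 1/x when op is 'div' so the division gradient can reuse the 'mul' path."""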
    return 1. / x if op == 'div' else x


def _addsub(op, x):
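    """Return -x when op is 'sub' so the subtraction gradient can reuse the 'add' path."""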
    return -x if op == 'sub' else x


def gspmm_real(gidx, op, reduce_op, X, Y):
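    """Forward computation and gradient function for generalized SpMM.

    Runs the forward generalized sparse-dense matmul and returns the output
    together with a grad closure for tf.custom_gradient. argX and argY hold
    the argmax/argmin indices produced by max/min reducers and are reused in
    the backward pass.
    """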
    out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)

    def grad(dZ):
        dZ = tensor(dZ)
        if op != 'copy_rhs':
            g_rev = gidx.reverse()
            if reduce_op == 'sum':
                if op in ['mul', 'div']:
                    dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
                elif op in ['add', 'sub']:
                    dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, Y)[0]
                elif op == 'copy_lhs':
                    dX = _gspmm(g_rev, 'copy_lhs', 'sum', dZ, None)[0]
            else:  # max/min reduce: use the saved argmax/argmin indices
                if op in ['mul', 'div']:
                    dX = _scatter_nd(
                        argX,
                        _muldiv(op, _gather_nd(argY, tf.broadcast_to(Y, (Y.shape[0], *dZ.shape[1:])))) * dZ,
                        X.shape[0])
                elif op in ['add', 'sub', 'copy_lhs']:
                    dX = _scatter_nd(argX, dZ, X.shape[0])
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = tf.zeros_like(X)
        if op != 'copy_lhs':
            if reduce_op == 'sum':
                if op == 'mul' and _need_reduce_last_dim(X, Y):
                    dY = _gsddmm(gidx, 'dot', X, dZ)
                elif op in ['mul', 'div']:
                    dY = _gsddmm(gidx, 'mul', X, dZ)
                    if op == 'div':
                        dY = -dY / (Y ** 2)
                elif op in ['add', 'sub', 'copy_rhs']:
                    dY = _gsddmm(gidx, 'copy_rhs', X, _addsub(op, dZ))
            else:  # max/min reduce: use the saved argmax/argmin indices
                out_shp = (Y.shape[0],) + dZ.shape[1:]
                if op in ['mul', 'div']:
                    dY = _scatter_nd(
                        argY,
                        _gather_nd(argX, tf.broadcast_to(X, (X.shape[0], *dZ.shape[1:]))) * dZ,
                        Y.shape[0])
                    if op == 'div':
                        dY = -dY / (Y ** 2)
                elif op in ['add', 'sub', 'copy_rhs']:
                    dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = tf.zeros_like(Y)
        return dX, dY
    return out, grad


def gspmm(gidx, op, reduce_op, X, Y):
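    """Generalized sparse-dense matrix multiplication with a custom gradient.

    None operands are replaced by dummy scalar tensors so that
    tf.custom_gradient always receives two tensor inputs.
    """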
    @tf.custom_gradient
    def _lambda(X, Y):
        return gspmm_real(gidx, op, reduce_op, X, Y)
    if X is None:
        X = tf.zeros(())
    if Y is None:
        Y = tf.zeros(())
    return _lambda(X, Y)


def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
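    """Forward computation and gradient function for generalized SDDMM.

    Runs the forward generalized sampled dense-dense matmul and returns the
    output together with a grad closure for tf.custom_gradient; gradients are
    routed back to nodes or edges according to lhs_target and rhs_target
    ('u', 'v' or 'e').
    """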
    out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)

    def grad(dZ):
        if op != 'copy_rhs':
            if lhs_target in ['u', 'v']:
                _gidx = gidx if lhs_target == 'v' else gidx.reverse()
                if op in ['add', 'sub', 'copy_lhs']:
                    dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                else:  # mul, div, dot
                    if rhs_target == lhs_target:
                        dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * _muldiv(op, Y)
                    elif rhs_target == 'e':
                        dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * _muldiv(op, Y))[0]
                    else:  # rhs_target = !lhs_target
                        dX = _gspmm(_gidx, 'mul', 'sum', _muldiv(op, Y), dZ)[0]
            else:  # lhs_target == 'e'
                if op in ['add', 'sub', 'copy_lhs']:
                    dX = dZ
                else:  # mul, div, dot
                    dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = tf.zeros_like(X)
        if op != 'copy_lhs':
            if rhs_target in ['u', 'v']:
                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                if op in ['add', 'sub', 'copy_rhs']:
                    dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                else:  # mul, div, dot
                    if lhs_target == rhs_target:
                        dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0] * X
                    elif lhs_target == 'e':
                        dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ * X)[0]
                    else:  # rhs_target = !lhs_target
                        dY = _gspmm(_gidx, 'mul', 'sum', X, dZ)[0]
                    if op == 'div':
                        dY = -dY / (Y ** 2)
            else:  # rhs_target == 'e'
                if op in ['add', 'sub', 'copy_rhs']:
                    dY = _addsub(op, dZ)
                else:  # mul, div, dot
                    dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
                    if op == 'div':
                        dY = -dY / (Y ** 2)
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = tf.zeros_like(Y)
        return dX, dY
    return out, grad


def gsddmm(gidx, op, X, Y, lhs_target='u', rhs_target='v'):
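    """Generalized sampled dense-dense matrix multiplication with a custom gradient.

    None operands are replaced by dummy scalar tensors so that
    tf.custom_gradient always receives two tensor inputs.
    """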
    @tf.custom_gradient
    def _lambda(X, Y):
        return gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target)
    if X is None:
        X = tf.zeros(())
    if Y is None:
        Y = tf.zeros(())
    return _lambda(X, Y)
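# Hedged usage sketch, not part of the module. It assumes the DGL graph's
# internal heterograph index is reachable as g._graph, a private attribute
# that may differ between DGL releases; everything else uses only the two
# wrappers defined above.
def _example_usage(g, node_feat, edge_feat):
    # Multiply source-node features by edge features and sum onto destinations.
    h = gspmm(g._graph, 'mul', 'sum', node_feat, edge_feat)
    # Per-edge dot product between source and destination node features.
    e = gsddmm(g._graph, 'dot', node_feat, node_feat, lhs_target='u', rhs_target='v')
    return h, e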