Unverified Commit c341520d authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Fix the behavior of min/max reducer for 1-dim dta. (#2250)

* udp

* add test

* udp

* fix mxnet
parent ee9093f5
......@@ -25,7 +25,10 @@ def _scatter_nd(index, src, n_rows):
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
if ndim > 1:
new_idx = index * stride + sum(offsets)
else:
new_idx = index
src = src.reshape(-1)
new_idx = new_idx.reshape(-1)
rst = np.zeros((stride * n_rows,), dtype=src.dtype)
......@@ -48,7 +51,10 @@ def _gather_nd(index, src):
offsets.append(
(stride * offset_i).reshape((1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx)
else:
new_idx = index
src = src.reshape(-1)
new_idx = new_idx.reshape(-1)
rst = nd.take(src, new_idx).reshape(shp)
......
......@@ -20,7 +20,10 @@ def _scatter_nd(index, src, n_rows):
offsets.append(
tf.reshape((stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx)
else:
new_idx = index
src = tf.reshape(src, (-1,))
new_idx = tf.reshape(new_idx, (-1, 1))
rst = tf.reshape(tf.scatter_nd(new_idx, src, (stride * n_rows,)), (n_rows, *shp[1:]))
......@@ -39,7 +42,10 @@ def _gather_nd(index, src):
offsets.append(
tf.reshape((stride * offset_i), (1,) * i + (di,) + (1,) * (ndim - 1 - i)))
stride *= di
if ndim > 1:
new_idx = index * stride + copy_to(sum(offsets), ctx)
else:
new_idx = index
src = tf.reshape(src, (-1,))
new_idx = tf.reshape(new_idx, (-1))
rst = tf.reshape(tf.gather(src, new_idx), shp)
......@@ -72,7 +78,7 @@ def _reduce_grad(grad, shape):
reduce_idx = np.asarray(np.nonzero(np.asarray(grad_shape) - np.asarray(in_shape)))
reduce_idx += 1 # skip batch dim
reduce_idx_tensor = tf.constant(tuple(
reduce_idx.flatten().tolist()))
reduce_idx.flatten().tolist()), dtype=tf.int32)
grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
return tf.reshape(grad, shape)
......
......@@ -167,6 +167,10 @@ def _gspmm(gidx, op, reduce_op, u, e):
# To deal with scalar node/edge features.
if (expand_u or not use_u) and (expand_e or not use_e):
v = F.squeeze(v, -1)
if expand_u and use_cmp:
arg_u = F.squeeze(arg_u, -1)
if expand_e and use_cmp:
arg_e = F.squeeze(arg_e, -1)
return v, (arg_u, arg_e)
......
......@@ -79,7 +79,8 @@ spmm_shapes = [
((3, 3), (1, 3)),
((1,), (3,)),
((3,), (1,)),
((1,), (1,))
((1,), (1,)),
((), ())
]
sddmm_shapes = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment