Unverified Commit fb4a0508 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Solve the boundary issue in backward function of segment sum (#2610)

* upd

* trigger

* upd
parent b04cf65b
...@@ -346,10 +346,10 @@ class SegmentReduce(mx.autograd.Function): ...@@ -346,10 +346,10 @@ class SegmentReduce(mx.autograd.Function):
offsets = self.offsets offsets = self.offsets
m = offsets[-1].asscalar() m = offsets[-1].asscalar()
if self.op == 'sum': if self.op == 'sum':
offsets_np = asnumpy(offsets[1:-1]) offsets_np = asnumpy(offsets[1:])
indices_np = np.zeros((m,), dtype=offsets_np.dtype) indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np)) np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
indices_np = np.cumsum(indices_np, -1) indices_np = np.cumsum(indices_np, -1)[:-1]
indices = zerocopy_from_numpy(indices_np) indices = zerocopy_from_numpy(indices_np)
dx = dy[indices] dx = dy[indices]
else: else:
......
...@@ -275,11 +275,13 @@ class SegmentReduce(th.autograd.Function): ...@@ -275,11 +275,13 @@ class SegmentReduce(th.autograd.Function):
arg, offsets = ctx.saved_tensors arg, offsets = ctx.saved_tensors
m = offsets[-1].item() m = offsets[-1].item()
if op == 'sum': if op == 'sum':
offsets = offsets[1:-1] offsets = offsets[1:]
# To address the issue of trailing zeros, related issue:
# https://github.com/dmlc/dgl/pull/2610
indices = th.zeros( indices = th.zeros(
(m,), device=offsets.device, dtype=offsets.dtype) (m + 1,), device=offsets.device, dtype=offsets.dtype)
indices.scatter_add_(0, offsets, th.ones_like(offsets)) indices.scatter_add_(0, offsets, th.ones_like(offsets))
indices = th.cumsum(indices, -1) indices = th.cumsum(indices, -1)[:-1]
dx = dy[indices] dx = dy[indices]
else: else:
dx = _bwd_segment_cmp(dy, arg, m) dx = _bwd_segment_cmp(dy, arg, m)
......
...@@ -261,10 +261,10 @@ def segment_reduce_real(op, x, offsets): ...@@ -261,10 +261,10 @@ def segment_reduce_real(op, x, offsets):
def segment_reduce_backward(dy): def segment_reduce_backward(dy):
m = x.shape[0] m = x.shape[0]
if op == 'sum': if op == 'sum':
offsets_np = asnumpy(offsets[1:-1]) offsets_np = asnumpy(offsets[1:])
indices_np = np.zeros((m,), dtype=offsets_np.dtype) indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np)) np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
indices_np = np.cumsum(indices_np, -1) indices_np = np.cumsum(indices_np, -1)[:-1]
indices = zerocopy_from_numpy(indices_np) indices = zerocopy_from_numpy(indices_np)
dx = tf.gather(dy, indices) dx = tf.gather(dy, indices)
else: else:
......
...@@ -261,7 +261,7 @@ def test_segment_reduce(reducer): ...@@ -261,7 +261,7 @@ def test_segment_reduce(reducer):
value = F.tensor(np.random.rand(10, 5)) value = F.tensor(np.random.rand(10, 5))
v1 = F.attach_grad(F.clone(value)) v1 = F.attach_grad(F.clone(value))
v2 = F.attach_grad(F.clone(value)) v2 = F.attach_grad(F.clone(value))
seglen = F.tensor([2, 3, 0, 4, 1]) seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx) u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
v = F.repeat(F.copy_to(F.arange(0, len(seglen), F.int32), ctx), v = F.repeat(F.copy_to(F.arange(0, len(seglen), F.int32), ctx),
seglen, dim=0) seglen, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment