Commit 3cc32a97 authored by Lingfan Yu's avatar Lingfan Yu Committed by Minjie Wang
Browse files

[Feature] Reduce messages with scatter_add in PyTorch (#427)

* implement pytorch spmm with gather and scatter add

* fix

* replace torch take with index_select

* comments

* comment about pytorch __getitem__ operator pitfall

* typo
parent fe44ffe5
...@@ -135,15 +135,17 @@ def zeros_like(input): ...@@ -135,15 +135,17 @@ def zeros_like(input):
def ones(shape, dtype, ctx): def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx) return th.ones(shape, dtype=dtype, device=ctx)
if TH_VERSION.version[0] == 0: def spmm(x, y):
# TODO(minjie): note this does not support autograd on the `x` tensor. dst, src = x._indices()
# should adopt a workaround using custom op. # scatter index
def spmm(x, y): index = dst.view(-1, 1).expand(-1, y.shape[1])
return th.spmm(x, y) # zero tensor to be scatter_add to
else: out = y.new_full((x.shape[0], y.shape[1]), 0)
# torch v1.0+ # look up src features and multiply by edge features
def spmm(x, y): # Note: using y[src] instead of index_select will lead to terrible
return th.sparse.mm(x, y) # performance in backward
feature = th.index_select(y, 0, src) * x._values().unsqueeze(-1)
return out.scatter_add(0, index, feature)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim): def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input) y = th.zeros(n_segs, *input.shape[1:]).to(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment