Commit 143938b7 authored by rusty1s's avatar rusty1s
Browse files

bugfix

parent 9a608051
import torch
import torch_scatter
from torch_scatter import segment_csr
def reduce(src, dim=None, reduce='add', deterministic=False):
if dim is None and src.has_value():
def reduce(src, dim=None, reduce='add', deterministic=False): if dim is None and src.has_value():
op = getattr(torch, 'sum' if reduce == 'add' else reduce)
return op(src.storage.value)
......@@ -34,7 +31,8 @@ def reduce(src, dim=None, reduce='add', deterministic=False):
op = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = op(value, dim=dense_dims)
# TODO: ARGOUT
if isinstance(value, tuple):
return (src.set_value(value[0], layout='csr'),) + value[1:]
return src.set_value(value, layout='csr')
if len(dense_dims) > 0 and len(sparse_dims) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment