Commit 139cc6c1 authored by rusty1s's avatar rusty1s
Browse files

Merge branch 'master' of github.com:rusty1s/pytorch_sparse

parents d3a94a25 18bd8b67
...@@ -12,7 +12,7 @@ for library in [ ...@@ -12,7 +12,7 @@ for library in [
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) library, [osp.dirname(__file__)]).origin)
if torch.version.cuda is not None: # pragma: no cover if torch.cuda.is_available() and torch.version.cuda: # pragma: no cover
cuda_version = torch.ops.torch_sparse.cuda_version() cuda_version = torch.ops.torch_sparse.cuda_version()
if cuda_version == -1: if cuda_version == -1:
......
...@@ -35,7 +35,7 @@ def reduction(src: SparseTensor, dim: Optional[int] = None, ...@@ -35,7 +35,7 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
if dim == 0 and value is not None: if dim == 0 and value is not None:
col = src.storage.col() col = src.storage.col()
return scatter(value, col, dim=0, dim_size=src.size(0)) return scatter(value, col, 0, None, src.size(1), reduce)
elif dim == 0 and value is None: elif dim == 0 and value is None:
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
return src.storage.colcount().to(src.dtype()) return src.storage.colcount().to(src.dtype())
......
...@@ -66,7 +66,7 @@ class SparseTensor(object): ...@@ -66,7 +66,7 @@ class SparseTensor(object):
value: Optional[torch.Tensor] = None value: Optional[torch.Tensor] = None
if has_value: if has_value:
value = mat._values() value = mat.values()
return SparseTensor(row=row, rowptr=None, col=col, value=value, return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)), sparse_sizes=(mat.size(0), mat.size(1)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment