Commit 8fb5428b authored by rusty1s's avatar rusty1s
Browse files

fixes

parent 78d9af48
...@@ -25,13 +25,16 @@ def sample(src: SparseTensor, num_neighbors: int, ...@@ -25,13 +25,16 @@ def sample(src: SparseTensor, num_neighbors: int,
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int, def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]: replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, _ = src.csr() rowptr, col, value = src.csr()
rowcount = src.storage.rowcount() rowcount = src.storage.rowcount()
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj( rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
rowptr, col, rowcount, subset, num_neighbors, replace) rowptr, col, rowcount, subset, num_neighbors, replace)
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=e_id, if value is not None:
value = value[e_id]
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=value,
sparse_sizes=(subset.size(0), n_id.size(0)), sparse_sizes=(subset.size(0), n_id.size(0)),
is_sorted=True) is_sorted=True)
......
...@@ -409,7 +409,7 @@ class SparseTensor(object): ...@@ -409,7 +409,7 @@ class SparseTensor(object):
# Conversions ############################################################# # Conversions #############################################################
def to_dense(self, options: Optional[torch.Tensor] = None): def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor:
row, col, value = self.coo() row, col, value = self.coo()
if value is not None: if value is not None:
...@@ -541,8 +541,8 @@ SparseTensor.__repr__ = __repr__ ...@@ -541,8 +541,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ########################################################### # Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
scipy.sparse.csc_matrix] csr_matrix, scipy.sparse.csc_matrix]
@torch.jit.ignore @torch.jit.ignore
...@@ -600,24 +600,3 @@ def to_scipy(self: SparseTensor, layout: Optional[str] = None, ...@@ -600,24 +600,3 @@ def to_scipy(self: SparseTensor, layout: Optional[str] = None,
SparseTensor.from_scipy = from_scipy SparseTensor.from_scipy = from_scipy
SparseTensor.to_scipy = to_scipy SparseTensor.to_scipy = to_scipy
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR <= 3):
def add(self, other):
if torch.is_tensor(other) or is_scalar(other):
return self.add(other)
return NotImplemented
def mul(self, other):
if torch.is_tensor(other) or is_scalar(other):
return self.mul(other)
return NotImplemented
torch.Tensor.__add__ = add
torch.Tensor.__mul__ = mul
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment