Commit 8fb5428b authored by rusty1s's avatar rusty1s
Browse files

fixes

parent 78d9af48
......@@ -25,13 +25,16 @@ def sample(src: SparseTensor, num_neighbors: int,
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, _ = src.csr()
rowptr, col, value = src.csr()
rowcount = src.storage.rowcount()
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
rowptr, col, rowcount, subset, num_neighbors, replace)
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=e_id,
if value is not None:
value = value[e_id]
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=value,
sparse_sizes=(subset.size(0), n_id.size(0)),
is_sorted=True)
......
......@@ -409,7 +409,7 @@ class SparseTensor(object):
# Conversions #############################################################
def to_dense(self, options: Optional[torch.Tensor] = None):
def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor:
row, col, value = self.coo()
if value is not None:
......@@ -541,8 +541,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
scipy.sparse.csc_matrix]
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
csr_matrix, scipy.sparse.csc_matrix]
@torch.jit.ignore
......@@ -600,24 +600,3 @@ def to_scipy(self: SparseTensor, layout: Optional[str] = None,
SparseTensor.from_scipy = from_scipy
SparseTensor.to_scipy = to_scipy
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR <= 3):
def add(self, other):
if torch.is_tensor(other) or is_scalar(other):
return self.add(other)
return NotImplemented
def mul(self, other):
if torch.is_tensor(other) or is_scalar(other):
return self.mul(other)
return NotImplemented
torch.Tensor.__add__ = add
torch.Tensor.__mul__ = mul
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment