Commit 3679da2b authored by rusty1s's avatar rusty1s
Browse files

remove warnings filter

parent 8c7560b0
...@@ -7,9 +7,6 @@ from torch_sparse.utils import Final ...@@ -7,9 +7,6 @@ from torch_sparse.utils import Final
layouts: Final[List[str]] = ['coo', 'csr', 'csc'] layouts: Final[List[str]] = ['coo', 'csr', 'csc']
# FIXME: Remove once `/` on `LongTensors` is officially removed from PyTorch.
warnings.filterwarnings("ignore", category=UserWarning)
def get_layout(layout: Optional[str] = None) -> str: def get_layout(layout: Optional[str] = None) -> str:
if layout is None: if layout is None:
...@@ -130,7 +127,9 @@ class SparseStorage(object): ...@@ -130,7 +127,9 @@ class SparseStorage(object):
if not is_sorted: if not is_sorted:
idx = self._col.new_zeros(self._col.numel() + 1) idx = self._col.new_zeros(self._col.numel() + 1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col idx[1:] = self.row()
idx[1:] *= self._sparse_sizes[1]
idx[1:] += self._col
if (idx[1:] < idx[:-1]).any(): if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort() perm = idx[1:].argsort()
self._row = self.row()[perm] self._row = self.row()[perm]
...@@ -238,7 +237,7 @@ class SparseStorage(object): ...@@ -238,7 +237,7 @@ class SparseStorage(object):
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)]) rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None: if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)]) rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
else: elif diff_0 < 0:
if rowptr is not None: if rowptr is not None:
rowptr = rowptr[:-diff_0] rowptr = rowptr[:-diff_0]
if rowcount is not None: if rowcount is not None:
...@@ -251,7 +250,7 @@ class SparseStorage(object): ...@@ -251,7 +250,7 @@ class SparseStorage(object):
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)]) colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None: if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)]) colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
else: elif diff_1 < 0:
if colptr is not None: if colptr is not None:
colptr = colptr[:-diff_1] colptr = colptr[:-diff_1]
if colcount is not None: if colcount is not None:
...@@ -280,7 +279,7 @@ class SparseStorage(object): ...@@ -280,7 +279,7 @@ class SparseStorage(object):
idx = self.sparse_size(1) * self.row() + self.col() idx = self.sparse_size(1) * self.row() + self.col()
row = idx / num_cols row = idx // num_cols
col = idx % num_cols col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long assert row.dtype == torch.long and col.dtype == torch.long
......
...@@ -397,8 +397,8 @@ class SparseTensor(object): ...@@ -397,8 +397,8 @@ class SparseTensor(object):
return mat return mat
def to_torch_sparse_coo_tensor( def to_torch_sparse_coo_tensor(self, dtype: Optional[int] = None
self, dtype: Optional[int] = None) -> torch.Tensor: ) -> torch.Tensor:
row, col, value = self.coo() row, col, value = self.coo()
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
...@@ -503,8 +503,8 @@ SparseTensor.__repr__ = __repr__ ...@@ -503,8 +503,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ########################################################### # Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.
scipy.sparse.csc_matrix] csr_matrix, scipy.sparse.csc_matrix]
@torch.jit.ignore @torch.jit.ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment