Commit 195053ef authored by rusty1s

transpose

parent 76bf1e8a
@@ -36,6 +36,7 @@ class SparseStorage(object):
         assert index.dtype == torch.long
         assert index.dim() == 2 and index.size(0) == 2
+        index = index.contiguous()
         if value is not None:
             assert value.device == index.device
...
@@ -4,6 +4,7 @@ import torch
 import scipy.sparse
 from torch_sparse.storage import SparseStorage
+from torch_sparse.transpose import t

 class SparseTensor(object):
@@ -128,9 +129,9 @@ class SparseTensor(object):
         rowptr, col, val1 = self.csr()
         colptr, row, val2 = self.csc()
-        index_symmetric = (rowptr == colptr).all() and (col == row).all()
-        value_symmetric = (val1 == val2).all() if self.has_value() else True
-        return index_symmetric and value_symmetric
+        index_sym = (rowptr == colptr).all() and (col == row).all()
+        value_sym = (val1 == val2).all().item() if self.has_value() else True
+        return index_sym.item() and value_sym

     def detach_(self):
         self._storage.apply_(lambda x: x.detach_())
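
Editorial note, not part of the commit: the revised is_symmetric() appends .item() so the method returns a plain Python bool instead of a 0-dim tensor. A minimal sketch of the difference, using toy tensors chosen purely for illustration:

# Editorial sketch, not part of the commit: .all() on a boolean comparison
# returns a 0-dim tensor; .item() unwraps it into a plain Python bool,
# which is what the revised is_symmetric() now returns.
import torch

rowptr = torch.tensor([0, 1, 2])
colptr = torch.tensor([0, 1, 2])

as_tensor = (rowptr == colptr).all()        # tensor(True), 0-dim bool tensor
as_bool = (rowptr == colptr).all().item()   # True, plain Python bool
print(type(as_tensor), type(as_bool))       # <class 'torch.Tensor'> <class 'bool'>
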
@@ -310,23 +311,11 @@ class SparseTensor(object):
 # Bindings ####################################################################

+SparseTensor.t = t

 # def set_diag(self, value):
 #     raise NotImplementedError

-# def t(self):
-#     storage = SparseStorage(
-#         self._col[self._arg_csr_to_csc],
-#         self._row[self._arg_csr_to_csc],
-#         self._value[self._arg_csr_to_csc] if self.has_value else None,
-#         self.sparse_size()[::-1],
-#         self._colptr,
-#         self._rowptr,
-#         self._arg_csc_to_csr,
-#         self._arg_csr_to_csc,
-#         is_sorted=True,
-#     )
-#     return self.__class__.from_storage(storage)
-#
 # def masked_select(self, mask):
 #     raise NotImplementedError
@@ -446,10 +435,17 @@ if __name__ == '__main__':
     data = dataset[0].to(device)
     value = torch.ones((data.num_edges, ), device=device)
+    value = None

     mat1 = SparseTensor(data.edge_index, value)
     print(mat1)
+    print(id(mat1))
+    mat1 = mat1.long()
+    print(id(mat1))
+    mat1 = mat1.long()
+    print(id(mat1))
+    mat1 = mat1.to(torch.bool)
+    print(mat1)
+    print(mat1.is_pinned())
     print(mat1.to_dense().size())
@@ -459,9 +455,15 @@ if __name__ == '__main__':
     print(mat1.to_scipy(layout='coo').todense().shape)
     print(mat1.to_scipy(layout='csr').todense().shape)
     print(mat1.to_scipy(layout='csc').todense().shape)
+    print(mat1.is_quadratic())
+    print(mat1.is_symmetric())
+    mat1 = mat1.t()
+    print(mat1)

     # mat1 = mat1.t()
     # mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
     #                                device=device)
     # mat2 = mat2.coalesce()
     # mat2 = mat2.t().coalesce()
...
@@ -26,3 +26,17 @@ def transpose(index, value, m, n, coalesced=True):
     if coalesced:
         index, value = coalesce(index, value, n, m)
     return index, value
+
+
+def t(mat):
+    ((row, col), value), perm = mat.coo(), mat._storage.csr_to_csc
+    storage = mat._storage.__class__(
+        index=torch.stack([col, row], dim=0)[:, perm],
+        value=value[perm] if mat.has_value() else None,
+        sparse_size=mat.sparse_size()[::-1],
+        rowptr=mat._storage._colptr,
+        colptr=mat._storage._rowptr,
+        csr_to_csc=mat._storage._csc_to_csr,
+        csc_to_csr=perm,
+        is_sorted=True)
+    return mat.__class__.from_storage(storage)
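
Editorial usage sketch, not part of the commit: a minimal example of the new t() binding. The import path and the hand-built 3x3 edge_index are assumptions for illustration only; the constructor call and the to_dense()/sparse_size() accessors follow the test block in the diff above.

# Editorial sketch, not part of the commit. Import path and the toy
# edge_index below are assumptions for illustration only.
import torch
from torch_sparse import SparseTensor  # assumed import path

# 3x3 matrix with nonzeros at (0, 1), (1, 2) and (2, 0).
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])
value = torch.ones(edge_index.size(1))

mat = SparseTensor(edge_index, value)
mat_t = mat.t()  # reuses the cached csr_to_csc permutation instead of re-sorting

print(mat_t.sparse_size())  # sparse size is reversed relative to mat
print(mat_t.to_dense())     # equals mat.to_dense().t()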