Commit c5d606bf authored by rusty1s's avatar rusty1s
Browse files

support PyTorch 1.8.0

parent a3cd60c0
...@@ -141,16 +141,12 @@ class SparseStorage(object): ...@@ -141,16 +141,12 @@ class SparseStorage(object):
@classmethod @classmethod
def empty(self): def empty(self):
self = SparseStorage.__new__(SparseStorage) row = torch.tensor([], dtype=torch.long)
self._row = None col = torch.tensor([], dtype=torch.long)
self._rowptr = None return SparseStorage(row=row, rowptr=None, col=col, value=None,
self._value = None sparse_sizes=(0, 0), rowcount=None, colptr=None,
self._rowcount = None colcount=None, csr2csc=None, csc2csr=None,
self._colptr = None is_sorted=True)
self._colcount = None
self._csr2csc = None
self._csc2csr = None
return self
def has_row(self) -> bool: def has_row(self) -> bool:
return self._row is not None return self._row is not None
......
...@@ -26,9 +26,15 @@ class SparseTensor(object): ...@@ -26,9 +26,15 @@ class SparseTensor(object):
@classmethod @classmethod
def from_storage(self, storage: SparseStorage): def from_storage(self, storage: SparseStorage):
self = SparseTensor.__new__(SparseTensor) out = SparseTensor(row=storage._row, rowptr=storage._rowptr,
self.storage = storage col=storage._col, value=storage._value,
return self sparse_sizes=storage._sparse_sizes, is_sorted=True)
out.storage._rowcount = storage._rowcount
out.storage._colptr = storage._colptr
out.storage._colcount = storage._colcount
out.storage._csr2csc = storage._csr2csc
out.storage._csc2csr = storage._csc2csr
return out
@classmethod @classmethod
def from_edge_index(self, edge_index: torch.Tensor, def from_edge_index(self, edge_index: torch.Tensor,
...@@ -109,14 +115,14 @@ class SparseTensor(object): ...@@ -109,14 +115,14 @@ class SparseTensor(object):
colcount[M:] = 0 colcount[M:] = 0
csr2csc = csc2csr = row csr2csc = csc2csr = row
storage: SparseStorage = SparseStorage( out = SparseTensor(row=row, rowptr=rowptr, col=col, value=value,
row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N), sparse_sizes=(M, N), is_sorted=True)
rowcount=rowcount, colptr=colptr, colcount=colcount, out.storage._rowcount = rowcount
csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True) out.storage._colptr = colptr
out.storage._colcount = colcount
self = SparseTensor.__new__(SparseTensor) out.storage._csr2csc = csr2csc
self.storage = storage out.storage._csc2csr = csc2csr
return self return out
def copy(self): def copy(self):
return self.from_storage(self.storage) return self.from_storage(self.storage)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment