Commit 41ca46ba authored by rusty1s's avatar rusty1s
Browse files

bugfix in _parse_to [ci-deploy]

parent 9911c25f
...@@ -12,25 +12,17 @@ from torch_sparse.utils import is_scalar ...@@ -12,25 +12,17 @@ from torch_sparse.utils import is_scalar
class SparseTensor(object): class SparseTensor(object):
storage: SparseStorage storage: SparseStorage
def __init__(self, def __init__(self, row: Optional[torch.Tensor] = None,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None, col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[int, int]] = None, sparse_sizes: Optional[Tuple[int, int]] = None,
is_sorted: bool = False): is_sorted: bool = False):
self.storage = SparseStorage( self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
row=row, value=value, sparse_sizes=sparse_sizes,
rowptr=rowptr, rowcount=None, colptr=None, colcount=None,
col=col, csr2csc=None, csc2csr=None,
value=value, is_sorted=is_sorted)
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=is_sorted)
@classmethod @classmethod
def from_storage(self, storage: SparseStorage): def from_storage(self, storage: SparseStorage):
...@@ -53,17 +45,12 @@ class SparseTensor(object): ...@@ -53,17 +45,12 @@ class SparseTensor(object):
if has_value: if has_value:
value = mat[row, col] value = mat[row, col]
return SparseTensor( return SparseTensor(row=row, rowptr=None, col=col, value=value,
row=row, sparse_sizes=(mat.size(0), mat.size(1)),
rowptr=None, is_sorted=True)
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True)
@classmethod @classmethod
def from_torch_sparse_coo_tensor(self, def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
mat: torch.Tensor,
has_value: bool = True): has_value: bool = True):
mat = mat.coalesce() mat = mat.coalesce()
index = mat._indices() index = mat._indices()
...@@ -73,20 +60,13 @@ class SparseTensor(object): ...@@ -73,20 +60,13 @@ class SparseTensor(object):
if has_value: if has_value:
value = mat._values() value = mat._values()
return SparseTensor( return SparseTensor(row=row, rowptr=None, col=col, value=value,
row=row, sparse_sizes=(mat.size(0), mat.size(1)),
rowptr=None, is_sorted=True)
col=col,
value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True)
@classmethod @classmethod
def eye(self, def eye(self, M: int, N: Optional[int] = None,
M: int, options: Optional[torch.Tensor] = None, has_value: bool = True,
N: Optional[int] = None,
options: Optional[torch.Tensor] = None,
has_value: bool = True,
fill_cache: bool = False): fill_cache: bool = False):
N = M if N is None else N N = M if N is None else N
...@@ -104,8 +84,8 @@ class SparseTensor(object): ...@@ -104,8 +84,8 @@ class SparseTensor(object):
value: Optional[torch.Tensor] = None value: Optional[torch.Tensor] = None
if has_value: if has_value:
if options is not None: if options is not None:
value = torch.ones( value = torch.ones(row.numel(), dtype=options.dtype,
row.numel(), dtype=options.dtype, device=row.device) device=row.device)
else: else:
value = torch.ones(row.numel(), device=row.device) value = torch.ones(row.numel(), device=row.device)
...@@ -128,17 +108,9 @@ class SparseTensor(object): ...@@ -128,17 +108,9 @@ class SparseTensor(object):
csr2csc = csc2csr = row csr2csc = csc2csr = row
storage: SparseStorage = SparseStorage( storage: SparseStorage = SparseStorage(
row=row, row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
rowptr=rowptr, rowcount=rowcount, colptr=colptr, colcount=colcount,
col=col, csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)
value=value,
sparse_sizes=(M, N),
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
self = SparseTensor.__new__(SparseTensor) self = SparseTensor.__new__(SparseTensor)
self.storage = storage self.storage = storage
...@@ -181,14 +153,12 @@ class SparseTensor(object): ...@@ -181,14 +153,12 @@ class SparseTensor(object):
def has_value(self) -> bool: def has_value(self) -> bool:
return self.storage.has_value() return self.storage.has_value()
def set_value_(self, def set_value_(self, value: Optional[torch.Tensor],
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
self.storage.set_value_(value, layout) self.storage.set_value_(value, layout)
return self return self
def set_value(self, def set_value(self, value: Optional[torch.Tensor],
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout)) return self.from_storage(self.storage.set_value(value, layout))
...@@ -217,31 +187,23 @@ class SparseTensor(object): ...@@ -217,31 +187,23 @@ class SparseTensor(object):
# Utility functions ####################################################### # Utility functions #######################################################
def fill_value_(self, def fill_value_(self, fill_value: float,
fill_value: float,
options: Optional[torch.Tensor] = None): options: Optional[torch.Tensor] = None):
if options is not None: if options is not None:
value = torch.full((self.nnz(), ), value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
fill_value,
dtype=options.dtype,
device=self.device()) device=self.device())
else: else:
value = torch.full((self.nnz(), ), value = torch.full((self.nnz(), ), fill_value,
fill_value,
device=self.device()) device=self.device())
return self.set_value_(value, layout='coo') return self.set_value_(value, layout='coo')
def fill_value(self, def fill_value(self, fill_value: float,
fill_value: float,
options: Optional[torch.Tensor] = None): options: Optional[torch.Tensor] = None):
if options is not None: if options is not None:
value = torch.full((self.nnz(), ), value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
fill_value,
dtype=options.dtype,
device=self.device()) device=self.device())
else: else:
value = torch.full((self.nnz(), ), value = torch.full((self.nnz(), ), fill_value,
fill_value,
device=self.device()) device=self.device())
return self.set_value(value, layout='coo') return self.set_value(value, layout='coo')
...@@ -308,13 +270,8 @@ class SparseTensor(object): ...@@ -308,13 +270,8 @@ class SparseTensor(object):
N = max(self.size(0), self.size(1)) N = max(self.size(0), self.size(1))
out = SparseTensor( out = SparseTensor(row=row, rowptr=None, col=col, value=value,
row=row, sparse_sizes=(N, N), is_sorted=False)
rowptr=None,
col=col,
value=value,
sparse_sizes=(N, N),
is_sorted=False)
out = out.coalesce(reduce) out = out.coalesce(reduce)
return out return out
...@@ -337,8 +294,7 @@ class SparseTensor(object): ...@@ -337,8 +294,7 @@ class SparseTensor(object):
else: else:
return False return False
def requires_grad_(self, def requires_grad_(self, requires_grad: bool = True,
requires_grad: bool = True,
options: Optional[torch.Tensor] = None): options: Optional[torch.Tensor] = None):
if requires_grad and not self.has_value(): if requires_grad and not self.has_value():
self.fill_value_(1., options=options) self.fill_value_(1., options=options)
...@@ -359,8 +315,8 @@ class SparseTensor(object): ...@@ -359,8 +315,8 @@ class SparseTensor(object):
if value is not None: if value is not None:
return value return value
else: else:
return torch.tensor( return torch.tensor(0., dtype=torch.float,
0., dtype=torch.float, device=self.storage.col().device) device=self.storage.col().device)
def device(self): def device(self):
return self.storage.col().device return self.storage.col().device
...@@ -368,8 +324,7 @@ class SparseTensor(object): ...@@ -368,8 +324,7 @@ class SparseTensor(object):
def cpu(self): def cpu(self):
return self.device_as(torch.tensor(0.), non_blocking=False) return self.device_as(torch.tensor(0.), non_blocking=False)
def cuda(self, def cuda(self, options: Optional[torch.Tensor] = None,
options: Optional[torch.Tensor] = None,
non_blocking: bool = False): non_blocking: bool = False):
if options is not None: if options is not None:
return self.device_as(options, non_blocking) return self.device_as(options, non_blocking)
...@@ -432,19 +387,19 @@ class SparseTensor(object): ...@@ -432,19 +387,19 @@ class SparseTensor(object):
row, col, value = self.coo() row, col, value = self.coo()
if value is not None: if value is not None:
mat = torch.zeros( mat = torch.zeros(self.sizes(), dtype=value.dtype,
self.sizes(), dtype=value.dtype, device=self.device()) device=self.device())
elif options is not None: elif options is not None:
mat = torch.zeros( mat = torch.zeros(self.sizes(), dtype=options.dtype,
self.sizes(), dtype=options.dtype, device=self.device()) device=self.device())
else: else:
mat = torch.zeros(self.sizes(), device=self.device()) mat = torch.zeros(self.sizes(), device=self.device())
if value is not None: if value is not None:
mat[row, col] = value mat[row, col] = value
else: else:
mat[row, col] = torch.ones( mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
self.nnz(), dtype=mat.dtype, device=mat.device) device=mat.device)
return mat return mat
...@@ -454,8 +409,8 @@ class SparseTensor(object): ...@@ -454,8 +409,8 @@ class SparseTensor(object):
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
if value is None: if value is None:
if options is not None: if options is not None:
value = torch.ones( value = torch.ones(self.nnz(), dtype=options.dtype,
self.nnz(), dtype=options.dtype, device=self.device()) device=self.device())
else: else:
value = torch.ones(self.nnz(), device=self.device()) value = torch.ones(self.nnz(), device=self.device())
...@@ -479,7 +434,7 @@ def is_shared(self: SparseTensor) -> bool: ...@@ -479,7 +434,7 @@ def is_shared(self: SparseTensor) -> bool:
def to(self, *args: Optional[List[Any]], def to(self, *args: Optional[List[Any]],
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor: **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs) device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
if dtype is not None: if dtype is not None:
self = self.type_as(torch.tensor(0., dtype=dtype)) self = self.type_as(torch.tensor(0., dtype=dtype))
...@@ -580,25 +535,16 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor: ...@@ -580,25 +535,16 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
value = torch.from_numpy(mat.data) value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2] sparse_sizes = mat.shape[:2]
storage = SparseStorage( storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
row=row, sparse_sizes=sparse_sizes, rowcount=None,
rowptr=rowptr, colptr=colptr, colcount=None, csr2csc=None,
col=col, csc2csr=None, is_sorted=True)
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=colptr,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return SparseTensor.from_storage(storage) return SparseTensor.from_storage(storage)
@torch.jit.ignore @torch.jit.ignore
def to_scipy(self: SparseTensor, def to_scipy(self: SparseTensor, layout: Optional[str] = None,
layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix: dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
assert self.dim() == 2 assert self.dim() == 2
layout = get_layout(layout) layout = get_layout(layout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment