Commit 3c415d25 authored by rusty1s's avatar rusty1s
Browse files

torch.device / torch.dtype args

parent 698be79e
...@@ -9,17 +9,18 @@ from .utils import dtypes, devices ...@@ -9,17 +9,18 @@ from .utils import dtypes, devices
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device): def test_eye(dtype, device):
options = torch.tensor(0, dtype=dtype, device=device) mat = SparseTensor.eye(3, dtype=dtype, device=device)
assert mat.storage.col().device == device
mat = SparseTensor.eye(3, options=options)
assert mat.storage.sparse_sizes() == (3, 3) assert mat.storage.sparse_sizes() == (3, 3)
assert mat.storage.row().tolist() == [0, 1, 2] assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3] assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
assert mat.storage.col().tolist() == [0, 1, 2] assert mat.storage.col().tolist() == [0, 1, 2]
assert mat.storage.value().tolist() == [1, 1, 1] assert mat.storage.value().tolist() == [1, 1, 1]
assert mat.storage.value().dtype == dtype
assert mat.storage.num_cached_keys() == 0 assert mat.storage.num_cached_keys() == 0
mat = SparseTensor.eye(3, options=options, has_value=False) mat = SparseTensor.eye(3, has_value=False)
assert mat.storage.col().device == device
assert mat.storage.sparse_sizes() == (3, 3) assert mat.storage.sparse_sizes() == (3, 3)
assert mat.storage.row().tolist() == [0, 1, 2] assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3] assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
...@@ -27,7 +28,8 @@ def test_eye(dtype, device): ...@@ -27,7 +28,8 @@ def test_eye(dtype, device):
assert mat.storage.value() is None assert mat.storage.value() is None
assert mat.storage.num_cached_keys() == 0 assert mat.storage.num_cached_keys() == 0
mat = SparseTensor.eye(3, 4, options=options, fill_cache=True) mat = SparseTensor.eye(3, 4, fill_cache=True)
assert mat.storage.col().device == device
assert mat.storage.sparse_sizes() == (3, 4) assert mat.storage.sparse_sizes() == (3, 4)
assert mat.storage.row().tolist() == [0, 1, 2] assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3] assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
...@@ -39,7 +41,8 @@ def test_eye(dtype, device): ...@@ -39,7 +41,8 @@ def test_eye(dtype, device):
assert mat.storage.csr2csc().tolist() == [0, 1, 2] assert mat.storage.csr2csc().tolist() == [0, 1, 2]
assert mat.storage.csc2csr().tolist() == [0, 1, 2] assert mat.storage.csc2csr().tolist() == [0, 1, 2]
mat = SparseTensor.eye(4, 3, options=options, fill_cache=True) mat = SparseTensor.eye(4, 3, fill_cache=True)
assert mat.storage.col().device == device
assert mat.storage.sparse_sizes() == (4, 3) assert mat.storage.sparse_sizes() == (4, 3)
assert mat.storage.row().tolist() == [0, 1, 2] assert mat.storage.row().tolist() == [0, 1, 2]
assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3] assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
......
...@@ -78,7 +78,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None, ...@@ -78,7 +78,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
return src.from_storage(storage) return src.from_storage(storage)
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor: def fill_diag(src: SparseTensor, fill_value: float,
k: int = 0) -> SparseTensor:
num_diag = min(src.sparse_size(0), src.sparse_size(1) - k) num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
if k < 0: if k < 0:
num_diag = min(src.sparse_size(0) + k, src.sparse_size(1)) num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))
......
...@@ -459,13 +459,14 @@ class SparseStorage(object): ...@@ -459,13 +459,14 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.clone() csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True)
def type_as(self, tensor=torch.Tensor): def type_as(self, tensor: torch.Tensor):
value = self._value value = self._value
if value is not None: if value is not None:
if tensor.dtype == value.dtype: if tensor.dtype == value.dtype:
...@@ -504,12 +505,49 @@ class SparseStorage(object): ...@@ -504,12 +505,49 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking) csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True) csc2csr=csc2csr, is_sorted=True)
def cuda(self):
    """Return a copy of this storage with all tensors on the CUDA device.

    Moves :obj:`col` first; if that is a no-op (the storage already
    lives on the GPU), :obj:`self` is returned unchanged.  Otherwise
    every optional cached tensor is transferred as well, and the sparse
    structure is preserved (``is_sorted=True`` — moving devices does not
    reorder entries).
    """
    new_col = self._col.cuda()
    if new_col.device == self._col.device:
        # Already on the CUDA device — nothing to copy.
        return self

    def _move(tensor):
        # Transfer an optional cached tensor, keeping `None` as-is.
        return None if tensor is None else tensor.cuda()

    return SparseStorage(row=_move(self._row), rowptr=_move(self._rowptr),
                         col=new_col, value=_move(self._value),
                         sparse_sizes=self._sparse_sizes,
                         rowcount=_move(self._rowcount),
                         colptr=_move(self._colptr),
                         colcount=_move(self._colcount),
                         csr2csc=_move(self._csr2csc),
                         csc2csr=_move(self._csc2csr), is_sorted=True)
def pin_memory(self): def pin_memory(self):
row = self._row row = self._row
if row is not None: if row is not None:
...@@ -536,6 +574,7 @@ class SparseStorage(object): ...@@ -536,6 +574,7 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.pin_memory() csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value, return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes, sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr, rowcount=rowcount, colptr=colptr,
......
...@@ -73,29 +73,22 @@ class SparseTensor(object): ...@@ -73,29 +73,22 @@ class SparseTensor(object):
is_sorted=True) is_sorted=True)
@classmethod @classmethod
def eye(self, M: int, N: Optional[int] = None, def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
options: Optional[torch.Tensor] = None, has_value: bool = True, dtype: Optional[int] = None, device: Optional[torch.device] = None,
fill_cache: bool = False): fill_cache: bool = False):
N = M if N is None else N N = M if N is None else N
if options is not None: row = torch.arange(min(M, N), device=device)
row = torch.arange(min(M, N), device=options.device)
else:
row = torch.arange(min(M, N))
col = row col = row
rowptr = torch.arange(M + 1, dtype=torch.long, device=row.device) rowptr = torch.arange(M + 1, device=row.device)
if M > N: if M > N:
rowptr[N + 1:] = N rowptr[N + 1:] = N
value: Optional[torch.Tensor] = None value: Optional[torch.Tensor] = None
if has_value: if has_value:
if options is not None: value = torch.ones(row.numel(), dtype=dtype, device=row.device)
value = torch.ones(row.numel(), dtype=options.dtype,
device=row.device)
else:
value = torch.ones(row.numel(), device=row.device)
rowcount: Optional[torch.Tensor] = None rowcount: Optional[torch.Tensor] = None
colptr: Optional[torch.Tensor] = None colptr: Optional[torch.Tensor] = None
...@@ -131,7 +124,7 @@ class SparseTensor(object): ...@@ -131,7 +124,7 @@ class SparseTensor(object):
return self.from_storage(self.storage.clone()) return self.from_storage(self.storage.clone())
def type_as(self, tensor=torch.Tensor): def type_as(self, tensor=torch.Tensor):
value = self.storage._value value = self.storage.value()
if value is None or tensor.dtype == value.dtype: if value is None or tensor.dtype == value.dtype:
return self return self
return self.from_storage(self.storage.type_as(tensor)) return self.from_storage(self.storage.type_as(tensor))
...@@ -199,23 +192,13 @@ class SparseTensor(object): ...@@ -199,23 +192,13 @@ class SparseTensor(object):
# Utility functions ####################################################### # Utility functions #######################################################
def fill_value_(self, fill_value: float, def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
options: Optional[torch.Tensor] = None): value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
device=self.device())
else:
value = torch.full((self.nnz(), ), fill_value,
device=self.device()) device=self.device())
return self.set_value_(value, layout='coo') return self.set_value_(value, layout='coo')
def fill_value(self, fill_value: float, def fill_value(self, fill_value: float, dtype: Optional[int] = None):
options: Optional[torch.Tensor] = None): value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
device=self.device())
else:
value = torch.full((self.nnz(), ), fill_value,
device=self.device()) device=self.device())
return self.set_value(value, layout='coo') return self.set_value(value, layout='coo')
...@@ -320,9 +303,9 @@ class SparseTensor(object): ...@@ -320,9 +303,9 @@ class SparseTensor(object):
return False return False
def requires_grad_(self, requires_grad: bool = True, def requires_grad_(self, requires_grad: bool = True,
options: Optional[torch.Tensor] = None): dtype: Optional[int] = None):
if requires_grad and not self.has_value(): if requires_grad and not self.has_value():
self.fill_value_(1., options=options) self.fill_value_(1., dtype)
value = self.storage.value() value = self.storage.value()
if value is not None: if value is not None:
...@@ -335,36 +318,25 @@ class SparseTensor(object): ...@@ -335,36 +318,25 @@ class SparseTensor(object):
def is_pinned(self) -> bool: def is_pinned(self) -> bool:
return self.storage.is_pinned() return self.storage.is_pinned()
def options(self) -> torch.Tensor:
value = self.storage.value()
if value is not None:
return value
else:
return torch.tensor(0., dtype=torch.float,
device=self.storage.col().device)
def device(self): def device(self):
return self.storage.col().device return self.storage.col().device
def cpu(self): def cpu(self):
return self.device_as(torch.tensor(0.), non_blocking=False) return self.device_as(torch.tensor(0), non_blocking=False)
def cuda(self, options: Optional[torch.Tensor] = None, def cuda(self):
non_blocking: bool = False): return self.from_storage(self.storage.cuda())
if options is not None:
return self.device_as(options, non_blocking)
else:
options = torch.tensor(0.).cuda()
return self.device_as(options, non_blocking)
def is_cuda(self) -> bool: def is_cuda(self) -> bool:
return self.storage.col().is_cuda return self.storage.col().is_cuda
def dtype(self): def dtype(self):
return self.options().dtype value = self.storage.value()
return value.dtype if value is not None else torch.float
def is_floating_point(self) -> bool: def is_floating_point(self) -> bool:
return torch.is_floating_point(self.options()) value = self.storage.value()
return torch.is_floating_point(value) if value is not None else True
def bfloat16(self): def bfloat16(self):
return self.type_as( return self.type_as(
...@@ -408,17 +380,14 @@ class SparseTensor(object): ...@@ -408,17 +380,14 @@ class SparseTensor(object):
# Conversions ############################################################# # Conversions #############################################################
def to_dense(self, options: Optional[torch.Tensor] = None) -> torch.Tensor: def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
row, col, value = self.coo() row, col, value = self.coo()
if value is not None: if value is not None:
mat = torch.zeros(self.sizes(), dtype=value.dtype, mat = torch.zeros(self.sizes(), dtype=value.dtype,
device=self.device()) device=self.device())
elif options is not None:
mat = torch.zeros(self.sizes(), dtype=options.dtype,
device=self.device())
else: else:
mat = torch.zeros(self.sizes(), device=self.device()) mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())
if value is not None: if value is not None:
mat[row, col] = value mat[row, col] = value
...@@ -428,25 +397,18 @@ class SparseTensor(object): ...@@ -428,25 +397,18 @@ class SparseTensor(object):
return mat return mat
def to_torch_sparse_coo_tensor(self, def to_torch_sparse_coo_tensor(self, dtype: Optional[int] = None):
options: Optional[torch.Tensor] = None):
row, col, value = self.coo() row, col, value = self.coo()
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
if value is None: if value is None:
if options is not None: value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
value = torch.ones(self.nnz(), dtype=options.dtype,
device=self.device())
else:
value = torch.ones(self.nnz(), device=self.device())
return torch.sparse_coo_tensor(index, value, self.sizes()) return torch.sparse_coo_tensor(index, value, self.sizes())
# Python Bindings ############################################################# # Python Bindings #############################################################
Dtype = Optional[torch.dtype]
Device = Optional[Union[torch.device, str]]
def share_memory_(self: SparseTensor) -> SparseTensor: def share_memory_(self: SparseTensor) -> SparseTensor:
self.storage.share_memory_() self.storage.share_memory_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment