Commit 5f4f9c55 authored by rusty1s's avatar rusty1s
Browse files

cpu and cuda methods

parent 3dbd2282
......@@ -455,6 +455,15 @@ def to(self, *args: Optional[List[Any]],
return self
def cpu(self) -> SparseTensor:
return self.device_as(torch.tensor(0., device='cpu'))
def cuda(self, device: Optional[Union[int, str]] = None,
non_blocking: bool = False):
return self.device_as(torch.tensor(0., device=device or 'cuda'))
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed...
......@@ -523,6 +532,8 @@ def __repr__(self: SparseTensor) -> str:
SparseTensor.share_memory_ = share_memory_
SparseTensor.is_shared = is_shared
SparseTensor.to = to
SparseTensor.cpu = cpu
SparseTensor.cuda = cuda
SparseTensor.__getitem__ = __getitem__
SparseTensor.__repr__ = __repr__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment