Commit 1f175220 authored by rusty1s's avatar rusty1s
Browse files

fix cuda

parent 24c599ea
...@@ -321,7 +321,8 @@ class SparseTensor(object): ...@@ -321,7 +321,8 @@ class SparseTensor(object):
def cpu(self): def cpu(self):
return self.device_as(torch.tensor(0.), non_blocking=False) return self.device_as(torch.tensor(0.), non_blocking=False)
def cuda(self, options=Optional[torch.Tensor], non_blocking: bool = False): def cuda(self, options: Optional[torch.Tensor] = None,
non_blocking: bool = False):
if options is not None: if options is not None:
return self.device_as(options, non_blocking) return self.device_as(options, non_blocking)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment