Commit 0fd9cfe2 authored by rusty1s

cleaner to

parent 40a19d20
--- a/torch_sparse/storage.py
+++ b/torch_sparse/storage.py
@@ -545,7 +545,6 @@ class SparseStorage(object):
         return is_pinned
 
-    @torch.jit.ignore
     def share_memory_(self) -> SparseStorage:
         row = self._row
         if row is not None:
@@ -574,7 +573,6 @@ def share_memory_(self) -> SparseStorage:
             csc2csr.share_memory_()
 
-    @torch.jit.ignore
     def is_shared(self) -> bool:
         is_shared = True
         row = self._row
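For context, a minimal usage sketch of the two methods touched above (assuming the public torch_sparse API, with `SparseTensor.from_dense` as the constructor): `share_memory_()` moves the cached index/value tensors into shared memory so worker processes can read the same sparse matrix without copying, and `is_shared()` reports whether every cached tensor lives there.

import torch
from torch_sparse import SparseTensor

# Build a small sparse matrix and move its storage into shared memory.
adj = SparseTensor.from_dense(torch.eye(4))
adj.share_memory_()     # shares row/col/value (and cached CSR/CSC) tensors
assert adj.is_shared()  # True once every cached tensor is in shared memory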
--- a/torch_sparse/tensor.py
+++ b/torch_sparse/tensor.py
@@ -399,29 +399,18 @@ Dtype = Optional[torch.dtype]
 Device = Optional[Union[torch.device, str]]
 
-@torch.jit.ignore
 def share_memory_(self: SparseTensor) -> SparseTensor:
     self.storage.share_memory_()
 
-@torch.jit.ignore
 def is_shared(self: SparseTensor) -> bool:
     return self.storage.is_shared()
 
-@torch.jit.ignore
 def to(self, *args: Optional[List[Any]],
        **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
-    dtype: Dtype = getattr(kwargs, 'dtype', None)
-    device: Device = getattr(kwargs, 'device', None)
-    non_blocking: bool = getattr(kwargs, 'non_blocking', False)
-    for arg in args:
-        if isinstance(arg, str) or isinstance(arg, torch.device):
-            device = arg
-        if isinstance(arg, torch.dtype):
-            dtype = arg
+    device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
 
     if dtype is not None:
         self = self.type_as(torch.tensor(0., dtype=dtype))
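The rewritten `to()` replaces the hand-rolled argument loop (which also mis-used `getattr` on the `kwargs` dict, where `kwargs.get` was intended, so keyword arguments were silently ignored) with `torch._C._nn._parse_to`, the internal helper `torch.nn.Module.to` itself uses. It accepts every calling convention of `Tensor.to` and returns a normalized `(device, dtype, non_blocking, memory_format)` tuple. A small sketch of its behavior:

import torch

# Each call form is normalized to (device, dtype, non_blocking, memory_format).
print(torch._C._nn._parse_to('cpu', non_blocking=True))
# -> (device(type='cpu'), None, True, None)
print(torch._C._nn._parse_to(torch.float64))
# -> (None, torch.float64, False, None)
print(torch._C._nn._parse_to('cpu', torch.half))
# -> (device(type='cpu'), torch.half, False, None)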
@@ -431,7 +420,6 @@ def to(self, *args: Optional[List[Any]],
     return self
 
-@torch.jit.ignore
 def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
     index = list(index) if isinstance(index, tuple) else [index]
     # More than one `Ellipsis` is not allowed...
@@ -474,7 +462,6 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
     return out
 
-@torch.jit.ignore
 def __repr__(self: SparseTensor) -> str:
     i = ' ' * 6
     row, col, value = self.coo()
@@ -564,11 +551,11 @@ SparseTensor.to_scipy = to_scipy
 
 # Hacky fixes #################################################################
 
-# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
+# Fix standard operators of `torch.Tensor` for PyTorch<=1.4.
 # https://github.com/pytorch/pytorch/pull/31769
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
-if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR < 4):
+if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR <= 4):
     def add(self, other):
         if torch.is_tensor(other) or is_scalar(other):
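The version bump reflects that PyTorch 1.4 still ships operators that raise on unknown operand types instead of returning `NotImplemented` (fixed upstream in pytorch/pytorch#31769), so the monkey-patch must stay active there too. A sketch of the pattern the hacky fix relies on (simplified; the exact patched slot and the scalar check via torch_sparse's `is_scalar` helper are assumptions, not taken from this commit):

import torch

_orig_add = torch.Tensor.__add__  # assumption: patching __add__, not .add()

def add(self, other):
    # Keep the original behavior for tensors and Python scalars; for any other
    # type, return NotImplemented so Python falls back to other.__radd__(self),
    # e.g. letting `dense + sparse` reach SparseTensor's reflected operator.
    if torch.is_tensor(other) or isinstance(other, (int, float)):
        return _orig_add(self, other)
    return NotImplemented

torch.Tensor.__add__ = add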