Commit f9b00093 authored by rusty1s's avatar rusty1s
Browse files

typing

parent 34b25b3c
from textwrap import indent from textwrap import indent
from typing import Optional, List, Tuple, Union from typing import Optional, List, Tuple, Dict, Union, Any
import torch import torch
import scipy.sparse import scipy.sparse
...@@ -404,7 +404,9 @@ def is_shared(self: SparseTensor) -> bool: ...@@ -404,7 +404,9 @@ def is_shared(self: SparseTensor) -> bool:
@torch.jit.ignore @torch.jit.ignore
def to(self, *args, **kwargs): def to(self, *args: Optional[List[Any]],
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
dtype: Dtype = getattr(kwargs, 'dtype', None) dtype: Dtype = getattr(kwargs, 'dtype', None)
device: Device = getattr(kwargs, 'device', None) device: Device = getattr(kwargs, 'device', None)
non_blocking: bool = getattr(kwargs, 'non_blocking', False) non_blocking: bool = getattr(kwargs, 'non_blocking', False)
...@@ -424,8 +426,7 @@ def to(self, *args, **kwargs): ...@@ -424,8 +426,7 @@ def to(self, *args, **kwargs):
@torch.jit.ignore @torch.jit.ignore
def __getitem__(self, index): def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
raise NotImplementedError
index = list(index) if isinstance(index, tuple) else [index] index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed... # More than one `Ellipsis` is not allowed...
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1: if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
...@@ -468,7 +469,7 @@ def __getitem__(self, index): ...@@ -468,7 +469,7 @@ def __getitem__(self, index):
@torch.jit.ignore @torch.jit.ignore
def __repr__(self): def __repr__(self: SparseTensor) -> str:
i = ' ' * 6 i = ' ' * 6
row, col, value = self.coo() row, col, value = self.coo()
infos = [] infos = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment