Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
0fd9cfe2
Commit
0fd9cfe2
authored
Feb 05, 2020
by
rusty1s
Browse files
cleaner to
parent
40a19d20
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
18 deletions
+3
-18
torch_sparse/storage.py
torch_sparse/storage.py
+0
-2
torch_sparse/tensor.py
torch_sparse/tensor.py
+3
-16
No files found.
torch_sparse/storage.py
View file @
0fd9cfe2
...
@@ -545,7 +545,6 @@ class SparseStorage(object):
...
@@ -545,7 +545,6 @@ class SparseStorage(object):
return
is_pinned
return
is_pinned
@
torch
.
jit
.
ignore
def
share_memory_
(
self
)
->
SparseStorage
:
def
share_memory_
(
self
)
->
SparseStorage
:
row
=
self
.
_row
row
=
self
.
_row
if
row
is
not
None
:
if
row
is
not
None
:
...
@@ -574,7 +573,6 @@ def share_memory_(self) -> SparseStorage:
...
@@ -574,7 +573,6 @@ def share_memory_(self) -> SparseStorage:
csc2csr
.
share_memory_
()
csc2csr
.
share_memory_
()
@
torch
.
jit
.
ignore
def
is_shared
(
self
)
->
bool
:
def
is_shared
(
self
)
->
bool
:
is_shared
=
True
is_shared
=
True
row
=
self
.
_row
row
=
self
.
_row
...
...
torch_sparse/tensor.py
View file @
0fd9cfe2
...
@@ -399,29 +399,18 @@ Dtype = Optional[torch.dtype]
...
@@ -399,29 +399,18 @@ Dtype = Optional[torch.dtype]
Device
=
Optional
[
Union
[
torch
.
device
,
str
]]
Device
=
Optional
[
Union
[
torch
.
device
,
str
]]
@
torch
.
jit
.
ignore
def
share_memory_
(
self
:
SparseTensor
)
->
SparseTensor
:
def
share_memory_
(
self
:
SparseTensor
)
->
SparseTensor
:
self
.
storage
.
share_memory_
()
self
.
storage
.
share_memory_
()
@
torch
.
jit
.
ignore
def
is_shared
(
self
:
SparseTensor
)
->
bool
:
def
is_shared
(
self
:
SparseTensor
)
->
bool
:
return
self
.
storage
.
is_shared
()
return
self
.
storage
.
is_shared
()
@
torch
.
jit
.
ignore
def
to
(
self
,
*
args
:
Optional
[
List
[
Any
]],
def
to
(
self
,
*
args
:
Optional
[
List
[
Any
]],
**
kwargs
:
Optional
[
Dict
[
str
,
Any
]])
->
SparseTensor
:
**
kwargs
:
Optional
[
Dict
[
str
,
Any
]])
->
SparseTensor
:
dtype
:
Dtype
=
getattr
(
kwargs
,
'dtype'
,
None
)
device
,
dtype
,
non_blocking
,
_
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
device
:
Device
=
getattr
(
kwargs
,
'device'
,
None
)
non_blocking
:
bool
=
getattr
(
kwargs
,
'non_blocking'
,
False
)
for
arg
in
args
:
if
isinstance
(
arg
,
str
)
or
isinstance
(
arg
,
torch
.
device
):
device
=
arg
if
isinstance
(
arg
,
torch
.
dtype
):
dtype
=
arg
if
dtype
is
not
None
:
if
dtype
is
not
None
:
self
=
self
.
type_as
(
torch
.
tensor
(
0.
,
dtype
=
dtype
))
self
=
self
.
type_as
(
torch
.
tensor
(
0.
,
dtype
=
dtype
))
...
@@ -431,7 +420,6 @@ def to(self, *args: Optional[List[Any]],
...
@@ -431,7 +420,6 @@ def to(self, *args: Optional[List[Any]],
return
self
return
self
@
torch
.
jit
.
ignore
def
__getitem__
(
self
:
SparseTensor
,
index
:
Any
)
->
SparseTensor
:
def
__getitem__
(
self
:
SparseTensor
,
index
:
Any
)
->
SparseTensor
:
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
# More than one `Ellipsis` is not allowed...
# More than one `Ellipsis` is not allowed...
...
@@ -474,7 +462,6 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
...
@@ -474,7 +462,6 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
return
out
return
out
@
torch
.
jit
.
ignore
def
__repr__
(
self
:
SparseTensor
)
->
str
:
def
__repr__
(
self
:
SparseTensor
)
->
str
:
i
=
' '
*
6
i
=
' '
*
6
row
,
col
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
...
@@ -564,11 +551,11 @@ SparseTensor.to_scipy = to_scipy
...
@@ -564,11 +551,11 @@ SparseTensor.to_scipy = to_scipy
# Hacky fixes #################################################################
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.
3
.
# Fix standard operators of `torch.Tensor` for PyTorch<=1.
4
.
# https://github.com/pytorch/pytorch/pull/31769
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
if
(
TORCH_MAJOR
<
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
<
4
):
if
(
TORCH_MAJOR
<
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
<
=
4
):
def
add
(
self
,
other
):
def
add
(
self
,
other
):
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment