Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8fb5428b
Commit
8fb5428b
authored
May 21, 2020
by
rusty1s
Browse files
fixes
parent
78d9af48
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
26 deletions
+8
-26
torch_sparse/sample.py
torch_sparse/sample.py
+5
-2
torch_sparse/tensor.py
torch_sparse/tensor.py
+3
-24
No files found.
torch_sparse/sample.py
View file @
8fb5428b
...
...
@@ -25,13 +25,16 @@ def sample(src: SparseTensor, num_neighbors: int,
def
sample_adj
(
src
:
SparseTensor
,
subset
:
torch
.
Tensor
,
num_neighbors
:
int
,
replace
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
rowptr
,
col
,
_
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
rowcount
=
src
.
storage
.
rowcount
()
rowptr
,
col
,
n_id
,
e_id
=
torch
.
ops
.
torch_sparse
.
sample_adj
(
rowptr
,
col
,
rowcount
,
subset
,
num_neighbors
,
replace
)
out
=
SparseTensor
(
rowptr
=
rowptr
,
row
=
None
,
col
=
col
,
value
=
e_id
,
if
value
is
not
None
:
value
=
value
[
e_id
]
out
=
SparseTensor
(
rowptr
=
rowptr
,
row
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
subset
.
size
(
0
),
n_id
.
size
(
0
)),
is_sorted
=
True
)
...
...
torch_sparse/tensor.py
View file @
8fb5428b
...
...
@@ -409,7 +409,7 @@ class SparseTensor(object):
# Conversions #############################################################
def
to_dense
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
def
to_dense
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
row
,
col
,
value
=
self
.
coo
()
if
value
is
not
None
:
...
...
@@ -541,8 +541,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
@
torch
.
jit
.
ignore
...
...
@@ -600,24 +600,3 @@ def to_scipy(self: SparseTensor, layout: Optional[str] = None,
SparseTensor
.
from_scipy
=
from_scipy
SparseTensor
.
to_scipy
=
to_scipy
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.3.
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
if
(
TORCH_MAJOR
<
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
<=
3
):
def
add
(
self
,
other
):
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
return
self
.
add
(
other
)
return
NotImplemented
def
mul
(
self
,
other
):
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
return
self
.
mul
(
other
)
return
NotImplemented
torch
.
Tensor
.
__add__
=
add
torch
.
Tensor
.
__mul__
=
mul
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment