Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8b290612
Commit
8b290612
authored
Jun 12, 2020
by
rusty1s
Browse files
fix sparse_reshape in PyTorch 1.4.0
parent
73a89efb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
91 deletions
+5
-91
test/test_padding.py
test/test_padding.py
+0
-90
torch_sparse/storage.py
torch_sparse/storage.py
+5
-1
No files found.
test/test_padding.py
deleted
100644 → 0
View file @
73a89efb
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse
import
SparseTensor
,
padded_index_select
from
.utils
import
grad_dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_padded_index_select
(
dtype
,
device
):
row
=
torch
.
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
3
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
2
,
3
,
1
,
3
,
2
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
binptr
=
torch
.
tensor
([
0
,
3
,
5
],
device
=
device
)
data
=
adj
.
padded_index
(
binptr
)
node_perm
,
row_perm
,
col_perm
,
mask
,
node_size
,
edge_size
=
data
assert
node_perm
.
tolist
()
==
[
2
,
3
,
0
,
1
]
assert
row_perm
.
tolist
()
==
[
2
,
2
,
3
,
-
1
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
-
1
]
assert
col_perm
.
tolist
()
==
[
1
,
3
,
2
,
-
1
,
0
,
1
,
2
,
3
,
0
,
2
,
3
,
-
1
]
assert
mask
.
long
().
tolist
()
==
[
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
]
assert
node_size
==
[
2
,
2
]
assert
edge_size
==
[
4
,
8
]
x
=
tensor
([
0
,
1
,
2
,
3
],
dtype
,
device
).
view
(
-
1
,
1
).
requires_grad_
()
x_j
=
padded_index_select
(
x
,
col_perm
)
assert
x_j
.
flatten
().
tolist
()
==
[
1
,
3
,
2
,
0
,
0
,
1
,
2
,
3
,
0
,
2
,
3
,
0
]
grad_out
=
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
],
dtype
,
device
)
x_j
.
backward
(
grad_out
.
view
(
-
1
,
1
))
assert
x
.
grad
.
flatten
().
tolist
()
==
[
12
,
5
,
17
,
18
]
def
test_padded_index_select_runtime
():
return
from
torch_geometric.datasets
import
Planetoid
device
=
torch
.
device
(
'cuda'
)
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
dataset
=
Planetoid
(
'/tmp/Planetoid'
,
name
=
'PubMed'
)
data
=
dataset
[
0
]
row
,
col
=
data
.
edge_index
.
to
(
device
)
adj
=
SparseTensor
(
row
=
row
,
col
=
col
)
rowcount
=
adj
.
storage
.
rowcount
().
to
(
device
)
rowptr
=
adj
.
storage
.
rowptr
().
to
(
device
)
binptr
=
torch
.
tensor
([
0
,
4
,
11
,
30
,
50
,
80
,
120
,
140
,
2000
]).
to
(
device
)
x
=
torch
.
randn
(
adj
.
size
(
0
),
512
).
to
(
device
)
data
=
torch
.
ops
.
torch_sparse
.
padded_index
(
rowptr
,
col
,
rowcount
,
binptr
)
node_perm
,
row_perm
,
col_perm
,
mask
,
node_sizes
,
edge_sizes
=
data
out
=
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
col_perm
,
torch
.
tensor
(
0.
))
outs
=
out
.
split
(
edge_sizes
)
for
out
,
size
in
zip
(
outs
,
node_sizes
):
print
(
out
.
view
(
size
,
-
1
,
x
.
size
(
-
1
)).
shape
)
for
i
in
range
(
110
):
if
i
==
10
:
start
.
record
()
torch
.
ops
.
torch_sparse
.
padded_index
(
rowptr
,
col
,
rowcount
,
binptr
)
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
'padded index'
,
start
.
elapsed_time
(
end
))
for
i
in
range
(
110
):
if
i
==
10
:
start
.
record
()
out
=
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
col_perm
,
torch
.
tensor
(
0.
))
out
.
split
(
edge_sizes
)
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
'padded index select'
,
start
.
elapsed_time
(
end
))
for
i
in
range
(
110
):
if
i
==
10
:
start
.
record
()
x
.
index_select
(
0
,
col
)
end
.
record
()
torch
.
cuda
.
synchronize
()
print
(
'index_select'
,
start
.
elapsed_time
(
end
))
torch_sparse/storage.py
View file @
8b290612
...
...
@@ -7,6 +7,9 @@ from torch_sparse.utils import Final
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
# FIXME: Remove once `/` on `LongTensors` is officially removed from PyTorch.
warnings
.
filterwarnings
(
"ignore"
,
category
=
UserWarning
)
def
get_layout
(
layout
:
Optional
[
str
]
=
None
)
->
str
:
if
layout
is
None
:
...
...
@@ -277,8 +280,9 @@ class SparseStorage(object):
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
()
+
self
.
col
()
row
=
idx
/
/
num_cols
row
=
idx
/
num_cols
col
=
idx
%
num_cols
assert
row
.
dtype
==
torch
.
long
and
col
.
dtype
==
torch
.
long
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
self
.
_value
,
sparse_sizes
=
(
num_rows
,
num_cols
),
rowcount
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment