OpenDAS / torch-sparse / Commits

Commit e7f4ef9f
authored Jan 26, 2020 by rusty1s

fix index select

parent 47b719bb
Showing 4 changed files with 59 additions and 28 deletions (+59, -28):

  test/test_index_select.py      +27   -0
  torch_sparse/index_select.py   +30  -26
  torch_sparse/storage.py         +1   -1
  torch_sparse/tensor.py          +1   -1

test/test_index_select.py (new file, 0 → 100644)
import time
from itertools import product

from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add

from .utils import dtypes, devices, tensor

devices = ['cpu']
dtypes = [torch.float]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_index_select(dtype, device):
    row = torch.tensor([0, 0, 1, 1, 2])
    col = torch.tensor([0, 1, 1, 2, 1])
    mat = SparseTensor(row=row, col=col)

    print()
    print(mat.to_dense())
    pass

    mat = mat.index_select(0, torch.tensor([0, 2]))
    print(mat.to_dense())
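A dense sketch of what the test prints (not part of the commit; it assumes that a SparseTensor built without an explicit value tensor densifies to ones at the stored positions):

# Hypothetical dense equivalent of the test above (not from the commit).
# Assumes unset entries densify to 0 and stored entries to 1.
import torch

dense = torch.zeros(3, 3)
dense[torch.tensor([0, 0, 1, 1, 2]), torch.tensor([0, 1, 1, 2, 1])] = 1.0

# mat.index_select(0, torch.tensor([0, 2])) keeps rows 0 and 2:
expected = dense[torch.tensor([0, 2])]
print(expected)  # [[1., 1., 0.], [0., 1., 0.]]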
torch_sparse/index_select.py

 import torch
+from torch_scatter import gather_csr

 from torch_sparse.storage import get_layout
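The newly imported gather_csr is torch_scatter's segment-wise gather: out[indptr[i]:indptr[i+1]] = src[i], i.e. each entry of a per-segment tensor is repeated according to a CSR pointer. A minimal sketch of the behavior the diff relies on (my example, not from the commit):

import torch
from torch_scatter import gather_csr

src = torch.tensor([10, 20, 30])
indptr = torch.tensor([0, 2, 2, 5])  # segment sizes 2, 0, 3
out = gather_csr(src, indptr)
print(out)  # tensor([10, 10, 30, 30, 30])

# For 1-D inputs this matches a repeat_interleave over the segment sizes:
print(src.repeat_interleave(indptr[1:] - indptr[:-1]))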
...
@@ -9,56 +10,58 @@ def index_select(src, dim, idx):
     assert idx.dim() == 1
     if dim == 0:
-        (row, col), value = src.coo()
+        old_rowptr, col, value = src.csr()
         rowcount = src.storage.rowcount
-        old_rowptr = src.storage.rowptr
         rowcount = rowcount[idx]

-        tmp = torch.arange(rowcount.size(0), device=rowcount.device)
-        row = tmp.repeat_interleave(rowcount)
-
-        # Creates an "arange interleave" tensor of col indices.
-        rowptr = torch.cat([row.new_zeros(1), rowcount.cumsum(0)], dim=0)
+        rowptr = col.new_zeros(idx.size(0) + 1)
+        torch.cumsum(rowcount, dim=0, out=rowptr[1:])
+
+        row = torch.arange(idx.size(0),
+                           device=col.device).repeat_interleave(rowcount)

         perm = torch.arange(row.size(0), device=row.device)
-        perm += (old_rowptr[idx] - rowptr[:-1])[row]
+        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
         col = col[perm]
-        index = torch.stack([row, col], dim=0)

         if src.has_value():
             value = value[perm]

-        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
+        sparse_size = torch.Size([idx.size(0), src.sparse_size(1)])

-        storage = src.storage.__class__(index, value, sparse_size,
-                                        rowcount=rowcount, rowptr=rowptr,
-                                        is_sorted=True)
+        storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
+                                        value=value, sparse_size=sparse_size,
+                                        rowcount=rowcount, is_sorted=True)

     elif dim == 1:
         old_colptr, row, value = src.csc()
         colcount = src.storage.colcount
         colcount = colcount[idx]

-        tmp = torch.arange(colcount.size(0), device=row.device)
-        col = tmp.repeat_interleave(colcount)
-
-        # Creates an "arange interleave" tensor of row indices.
-        colptr = torch.cat([col.new_zeros(1), colcount.cumsum(0)], dim=0)
+        col = torch.arange(idx.size(0),
+                           device=row.device).repeat_interleave(colcount)
+
+        colptr = row.new_zeros(idx.size(0) + 1)
+        torch.cumsum(colcount, dim=0, out=colptr[1:])

         perm = torch.arange(col.size(0), device=col.device)
-        perm += (old_colptr[idx] - colptr[:-1])[col]
+        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
         row = row[perm]

-        csc2csr = (colcount.size(0) * row + col).argsort()
-        index = torch.stack([row, col], dim=0)[:, csc2csr]
+        csc2csr = (idx.size(0) * row + col).argsort()
+        row, col = row[csc2csr], col[csc2csr]

         if src.has_value():
             value = value[perm][csc2csr]

-        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
+        sparse_size = torch.Size([src.sparse_size(0), idx.size(0)])

-        storage = src.storage.__class__(index, value, sparse_size,
-                                        colcount=colcount, colptr=colptr,
-                                        csc2csr=csc2csr, is_sorted=True)
+        storage = src.storage.__class__(row=row, col=col, value=value,
+                                        sparse_size=sparse_size, colptr=colptr,
+                                        colcount=colcount, csc2csr=csc2csr,
+                                        is_sorted=True)

     else:
         storage = src.storage.apply_value(
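The dim == 0 branch is the usual CSR row-selection recipe: gather the selected rows' nonzero counts, cumsum them into a new rowptr, and expand the per-row offset old_rowptr[idx] - rowptr[:-1] into a per-nonzero permutation with gather_csr. A standalone sketch of the same idea in plain PyTorch (an illustration with a hypothetical helper name, not code from this commit):

import torch

def select_rows_csr(rowptr, col, idx):
    # Hypothetical helper: select rows `idx` of a CSR matrix (rowptr, col).
    rowcount = rowptr[1:] - rowptr[:-1]      # nonzeros per row
    rowcount = rowcount[idx]                 # counts of the selected rows

    new_rowptr = col.new_zeros(idx.size(0) + 1)
    torch.cumsum(rowcount, dim=0, out=new_rowptr[1:])

    # For every output nonzero: start offset of its old row plus its rank
    # within that row; the repeat_interleave below plays the role of
    # gather_csr(rowptr[idx] - new_rowptr[:-1], new_rowptr).
    start = (rowptr[idx] - new_rowptr[:-1]).repeat_interleave(rowcount)
    perm = torch.arange(int(new_rowptr[-1])) + start
    return new_rowptr, col[perm], perm

# Example: the 3x3 matrix from the test file, keeping rows 0 and 2.
rowptr = torch.tensor([0, 2, 4, 5])
col = torch.tensor([0, 1, 1, 2, 1])
new_rowptr, new_col, perm = select_rows_csr(rowptr, col, torch.tensor([0, 2]))
print(new_rowptr)  # tensor([0, 2, 3])
print(new_col)     # tensor([0, 1, 1])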
...
@@ -73,14 +76,15 @@ def index_select_nnz(src, idx, layout=None):
     if get_layout(layout) == 'csc':
         idx = idx[src.storage.csc2csr]

-    index, value = src.coo()
-    index = index[:, idx]
+    row, col, value = src.coo()
+    row, col = row[idx], col[idx]

     if src.has_value():
         value = value[idx]

     # There is no other information we can maintain...
-    storage = src.storage.__class__(index, value, src.sparse_size(),
-                                    is_sorted=True)
+    storage = src.storage.__class__(row=row, col=col, value=value,
+                                    sparse_size=src.sparse_size(),
+                                    is_sorted=True)

     return src.from_storage(storage)
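index_select_nnz selects individual stored entries rather than whole rows or columns, so after this change it simply indexes the COO row, col and value vectors (the csc branch first remaps idx through csc2csr so the result stays in row-major order). A minimal sketch of the operation itself (illustration only, not the library code):

import torch

row = torch.tensor([0, 0, 1, 1, 2])
col = torch.tensor([0, 1, 1, 2, 1])
value = torch.tensor([1., 2., 3., 4., 5.])

nnz_idx = torch.tensor([1, 3, 4])  # keep the 2nd, 4th and 5th stored entries
row, col, value = row[nnz_idx], col[nnz_idx], value[nnz_idx]
print(row, col, value)  # tensor([0, 1, 2]) tensor([1, 2, 1]) tensor([2., 4., 5.])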
torch_sparse/storage.py

...
@@ -80,7 +80,7 @@ class SparseStorage(object):
         assert col.dim() == 1
         if sparse_size is None:
-            M = rowptr.numel() - 1 if rowptr is None else row.max().item() + 1
+            M = rowptr.numel() - 1 if row is None else row.max().item() + 1
             N = col.max().item() + 1
             sparse_size = torch.Size([M, N])
...
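The one-line fix corrects the shape inference: M should fall back to rowptr.numel() - 1 only when the COO row vector is missing, not when rowptr is missing (the old condition dereferenced rowptr exactly in the case where it was None). A small sketch of the intended inference (my illustration, using a hypothetical helper name):

import torch

def infer_sparse_size(row=None, rowptr=None, col=None):
    # M: number of rows, taken from rowptr when no COO row vector is given.
    M = rowptr.numel() - 1 if row is None else row.max().item() + 1
    # N: number of columns, from the largest column index.
    N = col.max().item() + 1
    return torch.Size([M, N])

print(infer_sparse_size(row=torch.tensor([0, 0, 1, 1, 2]),
                        col=torch.tensor([0, 1, 1, 2, 1])))     # torch.Size([3, 3])
print(infer_sparse_size(rowptr=torch.tensor([0, 2, 4, 5]),
                        col=torch.tensor([0, 1, 1, 2, 1])))     # torch.Size([3, 3])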
torch_sparse/tensor.py

...
@@ -355,7 +355,7 @@ class SparseTensor(object):
                             device=self.device,
                             requires_grad=requires_grad)

-    def to_scipy(self, dtype=None, layout="csr"):
+    def to_scipy(self, layout=None, dtype=None):
         assert self.dim() == 2
         layout = get_layout(layout)
...
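With the reordered signature, layout is the first keyword argument and defaults to None, which get_layout then resolves. A hypothetical call site under that assumption (not from the commit):

import torch
from torch_sparse import SparseTensor

mat = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]))
sp = mat.to_scipy(layout='csr')  # scipy sparse matrix in CSR layout (assumed supported)
print(sp.shape)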