Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8481d0d6
Commit
8481d0d6
authored
Jan 26, 2020
by
rusty1s
Browse files
masked select fix
parent
e7f4ef9f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
31 deletions
+30
-31
torch_sparse/masked_select.py
torch_sparse/masked_select.py
+30
-31
No files found.
torch_sparse/masked_select.py
View file @
8481d0d6
...
@@ -10,56 +10,54 @@ def masked_select(src, dim, mask):
...
@@ -10,56 +10,54 @@ def masked_select(src, dim, mask):
storage
=
src
.
storage
storage
=
src
.
storage
if
dim
==
0
:
if
dim
==
0
:
(
row
,
col
)
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
rowcount
=
src
.
storage
.
rowcount
rowcount
=
src
.
storage
.
rowcount
row_mask
=
mask
[
row
]
rowcount
=
rowcount
[
mask
]
rowcount
=
rowcount
[
mask
]
idx
=
torch
.
arange
(
rowcount
.
size
(
0
),
device
=
rowcount
.
device
)
row
=
idx
.
repeat_interleave
(
rowcount
)
mask
=
mask
[
row
]
col
=
col
[
row_mask
]
row
=
torch
.
arange
(
rowcount
.
size
(
0
),
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
device
=
row
.
device
).
repeat_interleave
(
rowcount
)
col
=
col
[
mask
]
if
src
.
has_value
():
if
src
.
has_value
():
value
=
value
[
row_
mask
]
value
=
value
[
mask
]
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
sparse_size
=
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
elif
dim
==
1
:
elif
dim
==
1
:
row
,
col
,
value
=
src
.
coo
()
csr2csc
=
src
.
storage
.
csr2csc
csr2csc
=
src
.
storage
.
csr2csc
row
=
src
.
storage
.
row
[
csr2csc
]
row
,
col
=
row
[
csr2csc
],
col
[
csr2csc
]
col
=
src
.
storage
.
col
[
csr2csc
]
colcount
=
src
.
storage
.
colcount
colcount
=
src
.
storage
.
colcount
col_mask
=
mask
[
col
]
colcount
=
colcount
[
mask
]
colcount
=
colcount
[
mask
]
tmp
=
torch
.
arange
(
colcount
.
size
(
0
),
device
=
row
.
device
)
col
=
tmp
.
repeat_interleave
(
colcount
)
mask
=
mask
[
col
]
row
=
row
[
col_mask
]
col
=
torch
.
arange
(
colcount
.
size
(
0
),
device
=
col
.
device
).
repeat_interleave
(
colcount
)
row
=
row
[
mask
]
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)[:,
csc2csr
]
row
,
col
=
row
[
csc2csr
],
col
[
csc2csr
]
value
=
src
.
storage
.
value
if
src
.
has_value
():
if
src
.
has_value
():
value
=
value
[
csr2csc
][
col_
mask
][
csc2csr
]
value
=
value
[
csr2csc
][
mask
][
csc2csr
]
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
storage
=
src
.
storage
.
__class__
(
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
index
,
sparse_size
=
sparse_size
,
value
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
sparse_size
,
is_sorted
=
True
)
colcount
=
colcount
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
else
:
else
:
idx
=
mask
.
nonzero
().
view
(
-
1
)
idx
=
mask
.
nonzero
().
view
(
-
1
)
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
index_select
(
storage
=
src
.
storage
.
apply_value
(
dim
-
1
,
idx
))
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
...
@@ -70,14 +68,15 @@ def masked_select_nnz(src, mask, layout=None):
...
@@ -70,14 +68,15 @@ def masked_select_nnz(src, mask, layout=None):
if
get_layout
(
layout
)
==
'csc'
:
if
get_layout
(
layout
)
==
'csc'
:
mask
=
mask
[
src
.
storage
.
csc2csr
]
mask
=
mask
[
src
.
storage
.
csc2csr
]
index
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
row
,
col
=
row
[
mask
],
col
[
mask
]
index
=
index
[:,
mask
]
if
src
.
has_value
():
if
src
.
has_value
():
value
=
value
[
mask
]
value
=
value
[
mask
]
# There is no other information we can maintain...
# There is no other information we can maintain...
storage
=
src
.
storage
.
__class__
(
storage
=
src
.
storage
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
index
,
value
,
src
.
sparse_size
(),
is_sorted
=
True
)
sparse_size
=
src
.
sparse_size
(),
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment