Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
9216364c
Commit
9216364c
authored
Dec 19, 2019
by
rusty1s
Browse files
__getitem__ numpy notation
parent
6b9127a0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
112 additions
and
37 deletions
+112
-37
torch_sparse/index_select.py
torch_sparse/index_select.py
+14
-10
torch_sparse/masked_select.py
torch_sparse/masked_select.py
+14
-10
torch_sparse/narrow.py
torch_sparse/narrow.py
+18
-9
torch_sparse/tensor.py
torch_sparse/tensor.py
+66
-8
No files found.
torch_sparse/index_select.py
View file @
9216364c
...
...
@@ -16,7 +16,7 @@ def arange_interleave(start, repeat):
def
index_select
(
src
,
dim
,
idx
):
dim
=
src
.
dim
()
-
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
assert
idx
.
dim
()
==
1
idx
=
idx
.
to
(
src
.
device
)
...
...
@@ -38,8 +38,8 @@ def index_select(src, dim, idx):
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
elif
dim
==
1
:
colptr
,
row
,
value
=
src
.
csc
()
...
...
@@ -58,13 +58,17 @@ def index_select(src, dim, idx):
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
else
:
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
return
src
.
from_storage
(
storage
)
...
...
@@ -82,7 +86,7 @@ def index_select_nnz(src, idx, layout=None):
value
=
value
[
idx
]
# There is no other information we can maintain...
storage
=
src
.
storage
.
__class__
(
index
,
value
,
src
.
sparse_size
(),
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
src
.
sparse_size
(),
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
torch_sparse/masked_select.py
View file @
9216364c
...
...
@@ -4,7 +4,7 @@ from torch_sparse.storage import get_layout
def
masked_select
(
src
,
dim
,
mask
):
dim
=
src
.
dim
()
-
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
assert
mask
.
dim
()
==
1
storage
=
src
.
storage
...
...
@@ -25,8 +25,8 @@ def masked_select(src, dim, mask):
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
elif
dim
==
1
:
csr2csc
=
src
.
storage
.
csr2csc
...
...
@@ -48,14 +48,18 @@ def masked_select(src, dim, mask):
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
else
:
idx
=
mask
.
nonzero
().
view
(
-
1
)
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
return
src
.
from_storage
(
storage
)
...
...
@@ -73,7 +77,7 @@ def masked_select_nnz(src, mask, layout=None):
value
=
value
[
mask
]
# There is no other information we can maintain...
storage
=
src
.
storage
.
__class__
(
index
,
value
,
src
.
sparse_size
(),
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
src
.
sparse_size
(),
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
torch_sparse/narrow.py
View file @
9216364c
...
...
@@ -2,7 +2,8 @@ import torch
def
narrow
(
src
,
dim
,
start
,
length
):
dim
=
src
.
dim
()
-
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
start
=
src
.
size
(
dim
)
+
start
if
start
<
0
else
start
if
dim
==
0
:
(
row
,
col
),
value
=
src
.
coo
()
...
...
@@ -25,9 +26,13 @@ def narrow(src, dim, start, length):
value
=
value
.
narrow
(
0
,
row_start
,
row_length
)
sparse_size
=
torch
.
Size
([
length
,
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
rowptr
=
rowptr
,
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
rowptr
=
rowptr
,
is_sorted
=
True
)
elif
dim
==
1
:
# This is faster than accessing `csc()` contrary to the `dim=0` case.
...
...
@@ -50,12 +55,16 @@ def narrow(src, dim, start, length):
value
=
value
[
mask
]
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
length
])
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
colcount
=
colcount
,
colptr
=
colptr
,
is_sorted
=
True
)
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
colcount
=
colcount
,
colptr
=
colptr
,
is_sorted
=
True
)
else
:
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
narrow
(
dim
-
1
,
start
,
length
))
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
narrow
(
dim
-
1
,
start
,
length
))
return
src
.
from_storage
(
storage
)
torch_sparse/tensor.py
View file @
9216364c
...
...
@@ -318,6 +318,48 @@ class SparseTensor(object):
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csc_matrix
((
value
,
row
,
colptr
),
self
.
size
())
# Standard Operators ######################################################
def
__getitem__
(
self
,
index
):
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
if
len
([
i
for
i
in
index
if
not
torch
.
is_tensor
(
i
)
and
i
==
...])
>
1
:
raise
SyntaxError
()
dim
=
0
out
=
self
while
len
(
index
)
>
0
:
item
=
index
.
pop
(
0
)
if
isinstance
(
item
,
int
):
out
=
out
.
select
(
dim
,
item
)
dim
+=
1
elif
isinstance
(
item
,
slice
):
if
item
.
step
is
not
None
:
raise
ValueError
(
'Step parameter not yet supported.'
)
start
=
0
if
item
.
start
is
None
else
item
.
start
start
=
self
.
size
(
dim
)
+
start
if
start
<
0
else
start
stop
=
self
.
size
(
dim
)
if
item
.
stop
is
None
else
item
.
stop
stop
=
self
.
size
(
dim
)
+
stop
if
stop
<
0
else
stop
out
=
out
.
narrow
(
dim
,
start
,
max
(
stop
-
start
,
0
))
dim
+=
1
elif
torch
.
is_tensor
(
item
):
if
item
.
dtype
==
torch
.
bool
:
out
=
out
.
masked_select
(
dim
,
item
)
dim
+=
1
elif
item
.
dtype
==
torch
.
long
:
out
=
out
.
index_select
(
dim
,
item
)
dim
+=
1
elif
item
==
Ellipsis
:
if
self
.
dim
()
-
len
(
index
)
<
dim
:
raise
SyntaxError
()
dim
=
self
.
dim
()
-
len
(
index
)
else
:
raise
SyntaxError
()
return
out
# String Reputation #######################################################
def
__repr__
(
self
):
...
...
@@ -457,18 +499,34 @@ if __name__ == '__main__':
dataset
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)
data
=
dataset
[
0
].
to
(
device
)
value
=
torch
.
randn
((
data
.
num_edges
,
),
device
=
device
)
value
=
torch
.
randn
(
data
.
num_edges
,
10
)
mat
=
SparseTensor
(
data
.
edge_index
,
value
)
index
=
torch
.
tensor
([
0
,
1
,
2
])
mask
=
torch
.
zeros
(
data
.
num_nodes
,
dtype
=
torch
.
bool
)
mask
[:
3
]
=
True
mat1
=
SparseTensor
(
data
.
edge_index
,
value
)
print
(
mat
[
1
].
size
())
print
(
mat
[
1
,
1
].
size
())
print
(
mat
[...,
-
1
].
size
())
print
(
mat
[:
10
,
...,
-
1
].
size
())
print
(
mat
[:,
-
1
].
size
())
print
(
mat
[
1
,
:,
-
1
].
size
())
print
(
mat
[
1
:
4
,
1
:
4
].
size
())
print
(
mat
[
index
].
size
())
print
(
mat
[
index
,
index
].
size
())
print
(
mat
[
mask
,
index
].
size
())
# mat[::-1]
# mat[::2]
mat1
=
SparseTensor
.
from_dense
(
mat1
.
to_dense
())
#
mat1 = SparseTensor.from_dense(mat1.to_dense())
print
(
mat1
)
mat
=
SparseTensor
.
from_torch_sparse_coo_tensor
(
mat1
.
to_torch_sparse_coo_tensor
())
#
print(mat1)
#
mat = SparseTensor.from_torch_sparse_coo_tensor(
#
mat1.to_torch_sparse_coo_tensor())
mat
=
SparseTensor
.
from_scipy
(
mat
.
to_scipy
(
layout
=
'csc'
))
print
(
mat
)
#
mat = SparseTensor.from_scipy(mat.to_scipy(layout='csc'))
#
print(mat)
# index = torch.tensor([0, 2])
# mat2 = mat1.index_select(2, index)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment