OpenDAS / torch-sparse · Commit 2554bf09

all select methods

Authored Dec 19, 2019 by rusty1s · parent 4a569c27
Showing 9 changed files with 454 additions and 180 deletions.
cpu/arange_interleave.cpp      +30  -0
setup.py                       +4   -1
torch_sparse/index_select.py   +88  -0
torch_sparse/masked_select.py  +79  -0
torch_sparse/narrow.py         +25  -10
torch_sparse/select.py         +2   -0
torch_sparse/storage.py        +78  -59
torch_sparse/tensor.py         +130 -95
torch_sparse/transpose.py      +18  -15
cpu/arange_interleave.cpp (new file, mode 100644)
#include <torch/extension.h>
#include "compat.h"

at::Tensor arange_interleave(at::Tensor start, at::Tensor repeat) {
  auto count = repeat.sum().DATA_PTR<int64_t>()[0];
  auto out = at::empty(count, start.options());
  auto repeat_data = repeat.DATA_PTR<int64_t>();

  AT_DISPATCH_ALL_TYPES(start.scalar_type(), "arange_interleave", [&] {
    auto start_data = start.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    int i = 0;
    for (int start_idx = 0; start_idx < start.size(0); start_idx++) {
      scalar_t init = start_data[start_idx];
      for (scalar_t rep_idx = 0; rep_idx < repeat_data[start_idx]; rep_idx++) {
        out_data[i] = init + rep_idx;
        i++;
      }
    }
  });

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("arange_interleave", &arange_interleave, "Arange Interleave (CPU)");
}
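For reference, the kernel's semantics (as implied by the loops above) are: for every pair (start[i], repeat[i]), emit the run start[i], start[i] + 1, ..., start[i] + repeat[i] - 1, concatenated over all i; the DATA_PTR macro presumably comes from compat.h to bridge PyTorch API versions. A pure-PyTorch sketch of the same behavior, useful as a mental model but not part of this commit:

import torch

def arange_interleave_reference(start, repeat):
    # Concatenate torch.arange(s, s + r) for every pair (s, r).
    init = start.repeat_interleave(repeat)         # run start values
    rank = torch.arange(int(repeat.sum()))         # global positions 0..N-1
    runs = torch.cat([repeat.new_zeros(1), repeat.cumsum(0)[:-1]])
    local = rank - runs.repeat_interleave(repeat)  # position within each run
    return init + local

print(arange_interleave_reference(torch.tensor([0, 10, 20]),
                                  torch.tensor([2, 0, 3])))
# tensor([ 0,  1, 20, 21, 22])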
setup.py
@@ -12,8 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     extra_compile_args += ['-DVERSION_GE_1_3']

 ext_modules = [
+    CppExtension('torch_sparse.arange_interleave_cpu',
+                 ['cpu/arange_interleave.cpp'],
+                 extra_compile_args=extra_compile_args),
     CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'],
-                 extra_compile_args=extra_compile_args)
+                 extra_compile_args=extra_compile_args),
 ]
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
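With the extension registered, a regular build (for example `python setup.py install`) compiles cpu/arange_interleave.cpp into its own importable module. A hedged usage sketch, with illustrative values:

import torch
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu

start = torch.tensor([0, 5])
repeat = torch.tensor([3, 2])
# concatenated runs [0, 3) + 0 and [0, 2) + 5:
print(arange_interleave_cpu.arange_interleave(start, repeat))
# tensor([0, 1, 2, 5, 6])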
torch_sparse/index_select.py (new file, mode 100644)
import torch
from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu


def __arange_interleave__(start, repeat):
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1
    assert repeat.dim() == 1
    assert start.numel() == repeat.numel()

    if start.is_cuda:
        raise NotImplementedError
    return arange_interleave_cpu.arange_interleave(start, repeat)


def index_select(src, dim, idx):
    dim = src.dim() - dim if dim < 0 else dim
    assert idx.dim() == 1
    idx = idx.to(src.device)

    if dim == 0:
        (_, col), value = src.coo()
        rowcount = src.storage.rowcount
        rowptr = src.storage.rowptr

        rowcount = rowcount[idx]

        tmp = torch.arange(rowcount.size(0), device=rowcount.device)
        row = tmp.repeat_interleave(rowcount)

        perm = __arange_interleave__(rowptr[idx], rowcount)
        col = col[perm]

        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[perm]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        colptr, row, value = src.csc()
        colcount = src.storage.colcount

        colcount = colcount[idx]

        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)

        perm = __arange_interleave__(colptr[idx], colcount)
        row = row[perm]

        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        if src.has_value():
            value = value[perm][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        idx = idx[src.storage.csc2csr]

    index, value = src.coo()
    index = index[:, idx]
    if src.has_value():
        value = value[idx]

    # There is no other information we can maintain...
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=True)

    return src.from_storage(storage)
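A minimal usage sketch for the two new functions (hedged: the toy data and the torch_sparse.tensor module path follow the __main__ block of tensor.py in this commit). index_select(0, idx) gathers whole rows and renumbers them 0..len(idx)-1, so the result matches dense indexing mat.to_dense()[idx]; index_select_nnz instead picks individual nonzeros by their position:

import torch
from torch_sparse.tensor import SparseTensor  # module path in this commit

index = torch.tensor([[0, 0, 1, 2],
                      [0, 2, 1, 0]])
value = torch.arange(4, dtype=torch.float)
mat = SparseTensor(index, value, torch.Size([3, 3]))

out = mat.index_select(0, torch.tensor([2, 0]))      # rows 2 and 0 -> 2 x 3
nnz = mat.index_select_nnz(torch.tensor([0, 3]), layout='coo')  # entries 0, 3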
torch_sparse/masked_select.py (new file, mode 100644)
import torch
from torch_sparse.storage import get_layout


def masked_select(src, dim, mask):
    dim = src.dim() - dim if dim < 0 else dim
    assert mask.dim() == 1
    storage = src.storage

    if dim == 0:
        (row, col), value = src.coo()
        rowcount = src.storage.rowcount

        row_mask = mask[row]
        rowcount = rowcount[mask]

        idx = torch.arange(rowcount.size(0), device=rowcount.device)
        row = idx.repeat_interleave(rowcount)
        col = col[row_mask]

        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[row_mask]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        csr2csc = src.storage.csr2csc
        row = src.storage.row[csr2csc]
        col = src.storage.col[csr2csc]
        colcount = src.storage.colcount

        col_mask = mask[col]
        colcount = colcount[mask]

        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)
        row = row[col_mask]

        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        value = src.storage.value
        if src.has_value():
            value = value[csr2csc][col_mask][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        idx = mask.nonzero().view(-1)
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def masked_select_nnz(src, mask, layout=None):
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        mask = mask[src.storage.csc2csr]

    index, value = src.coo()
    index = index[:, mask]
    if src.has_value():
        value = value[mask]

    # There is no other information we can maintain...
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=True)

    return src.from_storage(storage)
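The masked variants mirror the indexed ones (same hedged setup as above): masked_select(0, mask) keeps the rows where mask is True and compacts the result to mask.sum() rows, while masked_select_nnz filters individual nonzeros with a boolean mask of length nnz():

import torch
from torch_sparse.tensor import SparseTensor

index = torch.tensor([[0, 1, 1, 2],
                      [1, 0, 2, 2]])
mat = SparseTensor(index, torch.ones(4), torch.Size([3, 3]))

mask = torch.tensor([True, False, True])
out = mat.masked_select(0, mask)     # keeps rows 0 and 2 -> a 2 x 3 matrix

(row, col), _ = mat.coo()
diag = mat.masked_select_nnz(row == col, layout='coo')  # diagonal entries only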
torch_sparse/narrow.py
@@ -2,9 +2,16 @@ import torch

 def narrow(src, dim, start, length):
     dim = src.dim() - dim if dim < 0 else dim

     if dim == 0:
         (row, col), value = src.coo()
-        rowptr, _, _ = src.csr()
+        rowptr = src.storage.rowptr
+
+        # Maintain `rowcount`...
+        rowcount = src.storage._rowcount
+        if rowcount is not None:
+            rowcount = rowcount.narrow(0, start=start, length=length)
+
         rowptr = rowptr.narrow(0, start=start, length=length + 1)
         row_start = rowptr[0]
@@ -18,15 +25,22 @@ def narrow(src, dim, start, length):
             value = value.narrow(0, row_start, row_length)

         sparse_size = torch.Size([length, src.sparse_size(1)])
-        storage = src._storage.__class__(index, value, sparse_size,
-                                         rowptr=rowptr, is_sorted=True)
+        storage = src.storage.__class__(index, value, sparse_size,
+                                        rowcount=rowcount, rowptr=rowptr,
+                                        is_sorted=True)

     elif dim == 1:
-        # This is faster than accessing `csc()` in analogy to the `dim=0` case.
+        # This is faster than accessing `csc()` contrary to the `dim=0` case.
         (row, col), value = src.coo()
         mask = (col >= start) & (col < start + length)
-        colptr = src._storage._colptr
+
+        # Maintain `colcount`...
+        colcount = src.storage._colcount
+        if colcount is not None:
+            colcount = colcount.narrow(0, start=start, length=length)
+
+        # Maintain `colptr`...
+        colptr = src.storage._colptr
         if colptr is not None:
             colptr = colptr.narrow(0, start=start, length=length + 1)
             colptr = colptr - colptr[0]
@@ -36,11 +50,12 @@ def narrow(src, dim, start, length):
             value = value[mask]

         sparse_size = torch.Size([src.sparse_size(0), length])
-        storage = src._storage.__class__(index, value, sparse_size,
-                                         colptr=colptr, is_sorted=True)
+        storage = src.storage.__class__(index, value, sparse_size,
+                                        colcount=colcount, colptr=colptr,
+                                        is_sorted=True)

     else:
-        storage = src._storage.apply_value(
-            lambda x: x.narrow(dim - 1, start, length))
+        storage = src.storage.apply_value(
+            lambda x: x.narrow(dim - 1, start, length))

-    return src.__class__.from_storage(storage)
+    return src.from_storage(storage)
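As the updated comment notes, the dim=1 branch deliberately avoids materializing the CSC layout: it masks the COO entries into the requested column window and shifts them, narrowing the colcount/colptr caches alongside when they exist. A standalone sketch of that masking step, on toy data with hypothetical names:

import torch

row = torch.tensor([0, 0, 1, 2])    # stand-ins for `src.coo()`
col = torch.tensor([0, 2, 1, 2])
start, length = 1, 2                # keep columns [1, 3)

mask = (col >= start) & (col < start + length)
row, col = row[mask], col[mask] - start   # shift kept columns to [0, length)
print(row.tolist(), col.tolist())   # [0, 1, 2] [1, 0, 1]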
torch_sparse/select.py (new file, mode 100644)
def select(src, dim, idx):
    return src.narrow(dim, start=idx, length=1)
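Because select delegates to a length-1 narrow, the selected dimension is kept with size 1 rather than squeezed away:

# with any of the toy matrices above:
out = mat.select(0, 2)            # 1 x n SparseTensor holding row 2
assert out.sparse_size(0) == 1    # the dimension is kept, not squeezed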
torch_sparse/storage.py
@@ -20,19 +20,26 @@ class cached_property(object):
         return value


+layouts = ['coo', 'csr', 'csc']
+
+
+def get_layout(layout=None):
+    if layout is None:
+        layout = 'coo'
+        warnings.warn('`layout` argument unset, using default layout '
+                      '"coo". This may lead to unexpected behaviour.')
+    assert layout in layouts
+    return layout
+
+
 class SparseStorage(object):
-    layouts = ['coo', 'csr', 'csc']
-    cache_keys = ['rowptr', 'colptr', 'csr_to_csc', 'csc_to_csr']
-
-    def __init__(self, index, value=None, sparse_size=None, rowptr=None,
-                 colptr=None, csr_to_csc=None, csc_to_csr=None,
-                 is_sorted=False):
+    cache_keys = ['rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc',
+                  'csc2csr']
+
+    def __init__(self, index, value=None, sparse_size=None, rowcount=None,
+                 rowptr=None, colcount=None, colptr=None, csr2csc=None,
+                 csc2csr=None, is_sorted=False):
         assert index.dtype == torch.long
         assert index.dim() == 2 and index.size(0) == 2
@@ -46,25 +53,37 @@ class SparseStorage(object):
         if sparse_size is None:
             sparse_size = torch.Size((index.max(dim=-1)[0] + 1).tolist())

+        if rowcount is not None:
+            assert rowcount.dtype == torch.long
+            assert rowcount.device == index.device
+            assert rowcount.dim() == 1 and rowcount.numel() == sparse_size[0]
+
         if rowptr is not None:
-            assert rowptr.dtype == torch.long and rowptr.device == index.device
+            assert rowptr.dtype == torch.long
+            assert rowptr.device == index.device
             assert rowptr.dim() == 1 and rowptr.numel() - 1 == sparse_size[0]

+        if colcount is not None:
+            assert colcount.dtype == torch.long
+            assert colcount.device == index.device
+            assert colcount.dim() == 1 and colcount.numel() == sparse_size[1]
+
         if colptr is not None:
-            assert colptr.dtype == torch.long and colptr.device == index.device
+            assert colptr.dtype == torch.long
+            assert colptr.device == index.device
             assert colptr.dim() == 1 and colptr.numel() - 1 == sparse_size[1]

-        if csr_to_csc is not None:
-            assert csr_to_csc.dtype == torch.long
-            assert csr_to_csc.device == index.device
-            assert csr_to_csc.dim() == 1
-            assert csr_to_csc.numel() == index.size(1)
+        if csr2csc is not None:
+            assert csr2csc.dtype == torch.long
+            assert csr2csc.device == index.device
+            assert csr2csc.dim() == 1
+            assert csr2csc.numel() == index.size(1)

-        if csc_to_csr is not None:
-            assert csc_to_csr.dtype == torch.long
-            assert csc_to_csr.device == index.device
-            assert csc_to_csr.dim() == 1
-            assert csc_to_csr.numel() == index.size(1)
+        if csc2csr is not None:
+            assert csc2csr.dtype == torch.long
+            assert csc2csr.device == index.device
+            assert csc2csr.dim() == 1
+            assert csc2csr.numel() == index.size(1)

         if not is_sorted:
             idx = sparse_size[1] * index[0] + index[1]
@@ -73,18 +92,18 @@ class SparseStorage(object):
             perm = idx.argsort()
             index = index[:, perm]
             value = None if value is None else value[perm]
-            rowptr = None
-            colptr = None
-            csr_to_csc = None
-            csc_to_csr = None
+            csr2csc = None
+            csc2csr = None

         self._index = index
         self._value = value
         self._sparse_size = sparse_size
+        self._rowcount = rowcount
         self._rowptr = rowptr
+        self._colcount = colcount
         self._colptr = colptr
-        self._csr_to_csc = csr_to_csc
-        self._csc_to_csr = csc_to_csr
+        self._csr2csc = csr2csc
+        self._csc2csr = csc2csr

     @property
     def index(self):
@@ -106,27 +125,17 @@ class SparseStorage(object):
         return self._value

     def set_value_(self, value, layout=None):
-        if layout is None:
-            layout = 'coo'
-            warnings.warn('`layout` argument unset, using default layout '
-                          '"coo". This may lead to unexpected behaviour.')
-        assert layout in self.layouts
         assert value.device == self._index.device
         assert value.size(0) == self._index.size(1)
-        if value is not None and layout == 'csc':
-            value = value[self.csc_to_csr]
+        if value is not None and get_layout(layout) == 'csc':
+            value = value[self.csc2csr]
         return self.apply_value_(lambda x: value)

     def set_value(self, value, layout=None):
-        if layout is None:
-            layout = 'coo'
-            warnings.warn('`layout` argument unset, using default layout '
-                          '"coo". This may lead to unexpected behaviour.')
-        assert layout in self.layouts
         assert value.device == self._index.device
         assert value.size(0) == self._index.size(1)
-        if value is not None and layout == 'csc':
-            value = value[self.csc_to_csr]
+        if value is not None and get_layout(layout) == 'csc':
+            value = value[self.csc2csr]
         return self.apply_value(lambda x: value)

     def sparse_size(self, dim=None):
@@ -137,28 +146,34 @@ class SparseStorage(object):
         self._sparse_size == sizes
         return self

+    @cached_property
+    def rowcount(self):
+        one = torch.ones_like(self.row)
+        return segment_add(one, self.row, dim=0,
+                           dim_size=self._sparse_size[0])
+
     @cached_property
     def rowptr(self):
-        row = self.row
-        ones = torch.ones_like(row)
-        out_deg = segment_add(ones, row, dim=0,
-                              dim_size=self._sparse_size[0])
-        return torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
+        rowcount = self.rowcount
+        return torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

+    @cached_property
+    def colcount(self):
+        one = torch.ones_like(self.col)
+        return scatter_add(one, self.col, dim=0,
+                           dim_size=self._sparse_size[1])
+
     @cached_property
     def colptr(self):
-        col = self.col
-        ones = torch.ones_like(col)
-        in_deg = scatter_add(ones, col, dim=0,
-                             dim_size=self._sparse_size[1])
-        return torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
+        colcount = self.colcount
+        return torch.cat([colcount.new_zeros(1), colcount.cumsum(0)], dim=0)

     @cached_property
-    def csr_to_csc(self):
+    def csr2csc(self):
         idx = self._sparse_size[0] * self.col + self.row
         return idx.argsort()

     @cached_property
-    def csc_to_csr(self):
-        return self.csr_to_csc.argsort()
+    def csc2csr(self):
+        return self.csr2csc.argsort()

     def is_coalesced(self):
         raise NotImplementedError
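csr2csc is the permutation taking the row-major (CSR-ordered) nonzeros into column-major (CSC) order, and csc2csr is its inverse, which is exactly why it is computed as an argsort of csr2csc. A quick check of that identity:

import torch

perm = torch.randperm(10)   # stand-in for csr2csc
inv = perm.argsort()        # stand-in for csc2csr
assert torch.equal(inv[perm], torch.arange(10))  # argsort inverts a permutation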
@@ -202,10 +217,12 @@ class SparseStorage(object):
             self._index,
             optional(func, self._value),
             self._sparse_size,
+            self._rowcount,
             self._rowptr,
+            self._colcount,
             self._colptr,
-            self._csr_to_csc,
-            self._csc_to_csr,
+            self._csr2csc,
+            self._csc2csr,
             is_sorted=True,
         )
@@ -221,10 +238,12 @@ class SparseStorage(object):
             func(self._index),
             optional(func, self._value),
             self._sparse_size,
+            optional(func, self._rowcount),
             optional(func, self._rowptr),
+            optional(func, self._colcount),
             optional(func, self._colptr),
-            optional(func, self._csr_to_csc),
-            optional(func, self._csc_to_csr),
+            optional(func, self._csr2csc),
+            optional(func, self._csc2csr),
             is_sorted=True,
         )
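The refactor also makes the pointer arrays derived quantities: rowcount holds the number of nonzeros per row, and rowptr is its cumulative sum with a leading zero (likewise colcount/colptr), so the two caches stay consistent by construction. For example:

import torch

rowcount = torch.tensor([2, 0, 3])                               # nnz per row
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)])  # [0, 2, 2, 5]
# row i occupies nonzero positions rowptr[i]:rowptr[i + 1]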
torch_sparse/tensor.py
@@ -3,21 +3,24 @@ from textwrap import indent
 import torch
 import scipy.sparse

-from torch_sparse.storage import SparseStorage
+from torch_sparse.storage import SparseStorage, get_layout
 from torch_sparse.transpose import t
 from torch_sparse.narrow import narrow
+from torch_sparse.select import select
+from torch_sparse.index_select import index_select, index_select_nnz
+from torch_sparse.masked_select import masked_select, masked_select_nnz


 class SparseTensor(object):
     def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
-        self._storage = SparseStorage(index, value, sparse_size,
-                                      is_sorted=is_sorted)
+        self.storage = SparseStorage(index, value, sparse_size,
+                                     is_sorted=is_sorted)

     @classmethod
     def from_storage(self, storage):
         self = SparseTensor.__new__(SparseTensor)
-        self._storage = storage
+        self.storage = storage
         return self

     @classmethod
@@ -32,10 +35,10 @@ class SparseTensor(object):
         return self.__class__(index, value, mat.size()[:2], is_sorted=True)

     def __copy__(self):
-        return self.__class__.from_storage(self._storage)
+        return self.from_storage(self.storage)

     def clone(self):
-        return self.__class__.from_storage(self._storage.clone())
+        return self.from_storage(self.storage.clone())

     def __deepcopy__(self, memo):
         new_sparse_tensor = self.clone()
@@ -45,58 +48,57 @@ class SparseTensor(object):
     # Formats #################################################################

     def coo(self):
-        return self._storage.index, self._storage.value
+        return self.storage.index, self.storage.value

     def csr(self):
-        return self._storage.rowptr, self._storage.col, self._storage.value
+        return self.storage.rowptr, self.storage.col, self.storage.value

     def csc(self):
-        perm = self._storage.csr_to_csc
-        return (self._storage.colptr, self._storage.row[perm],
-                self._storage.value[perm] if self.has_value() else None)
+        perm = self.storage.csr2csc
+        return (self.storage.colptr, self.storage.row[perm],
+                self.storage.value[perm] if self.has_value() else None)

     # Storage inheritance #####################################################

     def has_value(self):
-        return self._storage.has_value()
+        return self.storage.has_value()

     def set_value_(self, value, layout=None):
-        self._storage.set_value_(value, layout)
+        self.storage.set_value_(value, layout)
         return self

     def set_value(self, value, layout=None):
-        storage = self._storage.set_value(value, layout)
-        return self.__class__.from_storage(storage)
+        return self.from_storage(self.storage.set_value(value, layout))

     def sparse_size(self, dim=None):
-        return self._storage.sparse_size(dim)
+        return self.storage.sparse_size(dim)

     def sparse_resize_(self, *sizes):
-        self._storage.sparse_resize_(*sizes)
+        self.storage.sparse_resize_(*sizes)
         return self

     def is_coalesced(self):
-        return self._storage.is_coalesced()
+        return self.storage.is_coalesced()

     def coalesce(self):
-        return self.__class__.from_storage(self._storage.coalesce())
+        return self.from_storage(self.storage.coalesce())

     def cached_keys(self):
-        return self._storage.cached_keys()
+        return self.storage.cached_keys()

     def fill_cache_(self, *args):
-        self._storage.fill_cache_(*args)
+        self.storage.fill_cache_(*args)
         return self

     def clear_cache_(self, *args):
-        self._storage.clear_cache_(*args)
+        self.storage.clear_cache_(*args)
         return self

     # Utility functions #######################################################

     def size(self, dim=None):
         size = self.sparse_size()
-        size += self._storage.value.size()[1:] if self.has_value() else ()
+        size += self.storage.value.size()[1:] if self.has_value() else ()
         return size if dim is None else size[dim]

     def dim(self):
@@ -107,7 +109,7 @@ class SparseTensor(object):
         return self.size()

     def nnz(self):
-        return self._storage.index.size(1)
+        return self.storage.index.size(1)

     def density(self):
         return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
@@ -138,50 +140,47 @@ class SparseTensor(object):
         return index_sym.item() and value_sym

     def detach_(self):
-        self._storage.apply_(lambda x: x.detach_())
+        self.storage.apply_(lambda x: x.detach_())
         return self

     def detach(self):
-        storage = self._storage.apply(lambda x: x.detach())
-        return self.__class__.from_storage(storage)
+        return self.from_storage(self.storage.apply(lambda x: x.detach()))

     def pin_memory(self):
-        storage = self._storage.apply(lambda x: x.pin_memory())
-        return self.__class__.from_storage(storage)
+        return self.from_storage(self.storage.apply(lambda x: x.pin_memory()))

     def is_pinned(self):
-        return all(self._storage.map(lambda x: x.is_pinned()))
+        return all(self.storage.map(lambda x: x.is_pinned()))

     def share_memory_(self):
-        self._storage.apply_(lambda x: x.share_memory_())
+        self.storage.apply_(lambda x: x.share_memory_())
         return self

     def is_shared(self):
-        return all(self._storage.map(lambda x: x.is_shared()))
+        return all(self.storage.map(lambda x: x.is_shared()))

     @property
     def device(self):
-        return self._storage.index.device
+        return self.storage.index.device

     def cpu(self):
-        storage = self._storage.apply(lambda x: x.cpu())
-        return self.__class__.from_storage(storage)
+        return self.from_storage(self.storage.apply(lambda x: x.cpu()))

     def cuda(self, device=None, non_blocking=False, **kwargs):
-        storage = self._storage.apply(
+        storage = self.storage.apply(
             lambda x: x.cuda(device, non_blocking, **kwargs))
-        return self.__class__.from_storage(storage)
+        return self.from_storage(storage)

     @property
     def is_cuda(self):
-        return self._storage.index.is_cuda
+        return self.storage.index.is_cuda

     @property
     def dtype(self):
-        return self._storage.value.dtype if self.has_value() else None
+        return self.storage.value.dtype if self.has_value() else None

     def is_floating_point(self):
-        value = self._storage.value
+        value = self.storage.value
         return self.has_value() and torch.is_floating_point(value)

     def type(self, dtype=None, non_blocking=False, **kwargs):
@@ -191,9 +190,10 @@ class SparseTensor(object):
         if dtype == self.dtype:
             return self

-        storage = self._storage.apply_value(
-            lambda x: x.type(dtype, non_blocking, **kwargs))
-        return self.__class__.from_storage(storage)
+        storage = self.storage.apply_value(
+            lambda x: x.type(dtype, non_blocking, **kwargs))
+        return self.from_storage(storage)

     def to(self, *args, **kwargs):
         storage = None
@@ -201,17 +201,17 @@ class SparseTensor(object):
         if 'device' in kwargs:
             device = kwargs['device']
             del kwargs['device']
-            storage = self._storage.apply(lambda x: x.to(
+            storage = self.storage.apply(lambda x: x.to(
                 device, non_blocking=getattr(kwargs, 'non_blocking', False)))

         for arg in args[:]:
             if isinstance(arg, str) or isinstance(arg, torch.device):
-                storage = self._storage.apply(lambda x: x.to(
+                storage = self.storage.apply(lambda x: x.to(
                     arg, non_blocking=getattr(kwargs, 'non_blocking', False)))
                 args.remove(arg)

         if storage is not None:
-            self = self.__class__.from_storage(storage)
+            self = self.from_storage(storage)

         if len(args) > 0 or len(kwargs) > 0:
             self = self.type(*args, **kwargs)
@@ -260,16 +260,13 @@ class SparseTensor(object):
     def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
         index, value = self.coo()
         return torch.sparse_coo_tensor(
-            index,
-            value if self.has_value() else torch.ones(
-                self.nnz(), dtype=dtype, device=self.device),
-            self.size(), device=self.device, requires_grad=requires_grad)
+            index, value if self.has_value() else torch.ones(
+                self.nnz(), dtype=dtype, device=self.device),
+            self.size(), device=self.device, requires_grad=requires_grad)

-    def to_scipy(self, dtype=None, layout='coo'):
+    def to_scipy(self, dtype=None, layout=None):
         assert self.dim() == 2
-        assert layout in self._storage.layouts
+        layout = get_layout(layout)

         if not self.has_value():
             ones = torch.ones(self.nnz(), dtype=dtype).numpy()
@@ -318,33 +315,20 @@ class SparseTensor(object):

 SparseTensor.t = t
 SparseTensor.narrow = narrow
-
-# def set_diag(self, value):
-#     raise NotImplementedError
-
-# def masked_select(self, mask):
-#     raise NotImplementedError
-
-# def index_select(self, index):
-#     raise NotImplementedError
-
-# def select(self, dim, index):
-#     raise NotImplementedError
-
-# def filter(self, index):
-#     assert self.is_symmetric
-#     assert index.dtype == torch.long or index.dtype == torch.bool
-#     raise NotImplementedError
-
-# def permute(self, index):
-#     assert index.dtype == torch.long
-#     return self.filter(index)
+SparseTensor.select = select
+SparseTensor.index_select = index_select
+SparseTensor.index_select_nnz = index_select_nnz
+SparseTensor.masked_select = masked_select
+SparseTensor.masked_select_nnz = masked_select_nnz

 # def __getitem__(self, idx):
 #     # Convert int and slice to index tensor
 #     # Filter list into edge and sparse slice
 #     raise NotImplementedError

+# def set_diag(self, value):
+#     raise NotImplementedError
+
 # def __reduce(self, dim, reduce, only_nnz):
 #     raise NotImplementedError
@@ -388,7 +372,7 @@ SparseTensor.narrow = narrow
 #             '"coo". This may lead to unexpected behaviour.')
 #         assert layout in ['coo', 'csr', 'csc']
 #         if layout == 'csc':
-#             other = other[self._arg_csc_to_csr]
+#             other = other[self._arg_csc2csr]
 #         if self.has_value:
 #             return self.set_value(self._value + other, 'coo')
 #         else:
@@ -440,10 +424,61 @@ if __name__ == '__main__':
     dataset = Planetoid('/tmp/Cora', 'Cora')
     data = dataset[0].to(device)
-    value = torch.randn((data.num_edges, ), device=device)
+    value = torch.randn((data.num_edges, 10), device=device)
     mat1 = SparseTensor(data.edge_index, value)

-    index = torch.tensor([0, 2])
-    mat2 = mat1.index_select(2, index)
+    index = torch.randperm(data.num_nodes)[:data.num_nodes - 500]
+    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
+    mask[index] = True
+
+    t = time.perf_counter()
+    for _ in range(1000):
+        mat2 = mat1.index_select(0, index)
+    print(time.perf_counter() - t)
+
+    t = time.perf_counter()
+    for _ in range(1000):
+        mat2 = mat1.masked_select(0, mask)
+    print(time.perf_counter() - t)
+
+    # mat2 = mat1.narrow(1, start=0, length=3)
+    # print(mat2)
+    # index = torch.randperm(data.num_nodes)
+    # t = time.perf_counter()
+    # for _ in range(1000):
+    #     mat2 = mat1.index_select(0, index)
+    # print(time.perf_counter() - t)
+    # t = time.perf_counter()
+    # for _ in range(1000):
+    #     mat2 = mat1.index_select(1, index)
+    # print(time.perf_counter() - t)
+    # raise NotImplementedError
+    # t = time.perf_counter()
+    # for _ in range(1000):
+    #     mat2 = mat1.t().index_select(0, index).t()
+    # print(time.perf_counter() - t)
+    # print(mat1)
+    # mat1.index_select((0, 1), torch.tensor([0, 1, 2, 3, 4]))
+    # print(mat3)
+    # print(mat3.storage.rowcount)
+    # print(mat1)
+    # (row, col), value = mat1.coo()
+    # mask = row < 3
+    # t = time.perf_counter()
+    # for _ in range(10000):
+    #     mat2 = mat1.narrow(1, start=10, length=2690)
+    # print(time.perf_counter() - t)
+    # # print(mat1.to_dense().size())
+    # print(mat1.to_torch_sparse_coo_tensor().to_dense().size())
@@ -461,24 +496,24 @@ if __name__ == '__main__':
     # print(mat1.cached_keys())
     # print('-------- NARROW ----------')

-    t = time.perf_counter()
-    for _ in range(100):
-        out = mat1.narrow(dim=0, start=10, length=10)
-        # out._storage.colptr
-    print(time.perf_counter() - t)
-    print(out)
-    print(out.cached_keys())
-
-    t = time.perf_counter()
-    for _ in range(100):
-        out = mat1.narrow(dim=1, start=10, length=2000)
-        # out._storage.colptr
-    print(time.perf_counter() - t)
-    print(out)
-    print(out.cached_keys())
+    # t = time.perf_counter()
+    # for _ in range(100):
+    #     out = mat1.narrow(dim=0, start=10, length=10)
+    #     # out.storage.colptr
+    # print(time.perf_counter() - t)
+    # print(out)
+    # print(out.cached_keys())
+
+    # t = time.perf_counter()
+    # for _ in range(100):
+    #     out = mat1.narrow(dim=1, start=10, length=2000)
+    #     # out.storage.colptr
+    # print(time.perf_counter() - t)
+    # print(out)
+    # print(out.cached_keys())

     # mat1 = mat1.narrow(0, start=10, length=10)
-    # mat1._storage._value = torch.randn(mat1.nnz(), 20)
+    # mat1.storage._value = torch.randn(mat1.nnz(), 20)
     # print(mat1.coo()[1].size())
     # mat1 = mat1.narrow(2, start=10, length=10)
    # print(mat1.coo()[1].size())
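With get_layout in place, to_scipy (like the set_value pair in storage.py) now warns and falls back to 'coo' when the layout argument is omitted instead of silently assuming it; passing the layout explicitly is the intended use. A hedged sketch, reusing a toy `mat` from above (the concrete scipy return types are assumed, not shown in this hunk):

csr = mat.to_scipy(layout='csr')   # presumably a scipy.sparse.csr_matrix
coo = mat.to_scipy()               # warns, then falls back to the 'coo' layout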
torch_sparse/transpose.py
@@ -28,18 +28,21 @@ def transpose(index, value, m, n, coalesced=True):
     return index, value


-def t(mat):
-    (row, col), value = mat.coo()
-    csr_to_csc = mat._storage.csr_to_csc
-    storage = mat._storage.__class__(
-        index=torch.stack([col, row], dim=0)[:, csr_to_csc],
-        value=value[csr_to_csc] if mat.has_value() else None,
-        sparse_size=mat.sparse_size()[::-1],
-        rowptr=mat._storage._colptr,
-        colptr=mat._storage._rowptr,
-        csr_to_csc=mat._storage._csc_to_csr,
-        csc_to_csr=csr_to_csc,
-        is_sorted=True)
-    return mat.__class__.from_storage(storage)
+def t(src):
+    (row, col), value = src.coo()
+    csr2csc = src.storage.csr2csc
+    storage = src.storage.__class__(
+        index=torch.stack([col, row], dim=0)[:, csr2csc],
+        value=value[csr2csc] if src.has_value() else None,
+        sparse_size=src.sparse_size()[::-1],
+        rowcount=src.storage._colcount,
+        rowptr=src.storage._colptr,
+        colcount=src.storage._rowcount,
+        colptr=src.storage._rowptr,
+        csr2csc=src.storage._csc2csr,
+        csc2csr=csr2csc,
+        is_sorted=True,
+    )
+    return src.from_storage(storage)
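Transposition swaps the roles of rows and columns, so t() can recycle every cache instead of recomputing it: the old colcount/colptr become the new rowcount/rowptr, and the new csr2csc is the old csc2csr. A small consistency check on a toy matrix from above:

out = mat.t()   # the transpose, built by relabeling cached metadata
assert out.sparse_size() == mat.sparse_size()[::-1]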