Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
2ae73b17
Commit
2ae73b17
authored
Jan 25, 2020
by
rusty1s
Browse files
new storage format
parent
fa763bac
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
174 additions
and
159 deletions
+174
-159
torch_sparse/storage.py
torch_sparse/storage.py
+174
-159
No files found.
torch_sparse/storage.py
View file @
2ae73b17
...
...
@@ -68,89 +68,114 @@ def get_layout(layout=None):
class
SparseStorage
(
object
):
cache_keys
=
[
'rowcount'
,
'rowptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
]
cache_keys
=
[
'rowcount'
,
'colptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
rowcount
=
None
,
rowptr
=
None
,
col
count
=
None
,
colptr
=
None
,
c
sr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
def
__init__
(
self
,
row
=
None
,
rowptr
=
None
,
col
=
None
,
value
=
None
,
sparse_size
=
None
,
row
count
=
None
,
colptr
=
None
,
c
olcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
assert
index
.
dtype
==
torch
.
long
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
index
=
index
.
contiguous
()
if
value
is
not
None
:
assert
value
.
device
==
index
.
device
assert
value
.
size
(
0
)
==
index
.
size
(
1
)
value
=
value
.
contiguous
()
assert
row
is
not
None
or
rowptr
is
not
None
assert
col
is
not
None
assert
col
.
dtype
==
torch
.
long
assert
col
.
dim
()
==
1
if
sparse_size
is
None
:
sparse_size
=
torch
.
Size
((
index
.
max
(
dim
=-
1
)[
0
]
+
1
).
tolist
())
M
=
rowptr
.
numel
()
-
1
if
rowptr
is
None
else
row
.
max
().
item
()
+
1
N
=
col
.
max
().
item
()
+
1
sparse_size
=
torch
.
Size
([
M
,
N
])
if
rowcount
is
not
None
:
assert
rowcount
.
dtype
==
torch
.
long
assert
rowcount
.
device
==
index
.
device
assert
rowcount
.
dim
()
==
1
and
rowcount
.
numel
()
==
sparse_size
[
0
]
if
row
is
not
None
:
assert
row
.
dtype
==
torch
.
long
assert
row
.
device
==
col
.
device
assert
row
.
dim
()
==
1
assert
row
.
numel
()
==
col
.
numel
()
if
rowptr
is
not
None
:
assert
rowptr
.
dtype
==
torch
.
long
assert
rowptr
.
device
==
index
.
device
assert
rowptr
.
dim
()
==
1
and
rowptr
.
numel
()
-
1
==
sparse_size
[
0
]
assert
rowptr
.
device
==
col
.
device
assert
rowptr
.
dim
()
==
1
assert
rowptr
.
numel
()
-
1
==
sparse_size
[
0
]
if
colcount
is
not
None
:
assert
colcount
.
dtype
==
torch
.
long
assert
colcount
.
device
==
index
.
device
assert
colcount
.
dim
()
==
1
and
colcount
.
numel
()
==
sparse_size
[
1
]
if
value
is
not
None
:
assert
value
.
device
==
col
.
device
assert
value
.
size
(
0
)
==
col
.
size
(
0
)
value
=
value
.
contiguous
()
if
rowcount
is
not
None
:
assert
rowcount
.
dtype
==
torch
.
long
assert
rowcount
.
device
==
col
.
device
assert
rowcount
.
dim
()
==
1
assert
rowcount
.
numel
()
==
sparse_size
[
0
]
if
colptr
is
not
None
:
assert
colptr
.
dtype
==
torch
.
long
assert
colptr
.
device
==
index
.
device
assert
colptr
.
dim
()
==
1
and
colptr
.
numel
()
-
1
==
sparse_size
[
1
]
assert
colptr
.
device
==
col
.
device
assert
colptr
.
dim
()
==
1
assert
colptr
.
numel
()
-
1
==
sparse_size
[
1
]
if
colcount
is
not
None
:
assert
colcount
.
dtype
==
torch
.
long
assert
colcount
.
device
==
col
.
device
assert
colcount
.
dim
()
==
1
assert
colcount
.
numel
()
==
sparse_size
[
1
]
if
csr2csc
is
not
None
:
assert
csr2csc
.
dtype
==
torch
.
long
assert
csr2csc
.
device
==
index
.
device
assert
csr2csc
.
device
==
col
.
device
assert
csr2csc
.
dim
()
==
1
assert
csr2csc
.
numel
()
==
index
.
size
(
1
)
assert
csr2csc
.
numel
()
==
col
.
size
(
0
)
if
csc2csr
is
not
None
:
assert
csc2csr
.
dtype
==
torch
.
long
assert
csc2csr
.
device
==
index
.
device
assert
csc2csr
.
device
==
col
.
device
assert
csc2csr
.
dim
()
==
1
assert
csc2csr
.
numel
()
==
index
.
size
(
1
)
assert
csc2csr
.
numel
()
==
col
.
size
(
0
)
if
not
is_sorted
:
idx
=
sparse_size
[
1
]
*
index
[
0
]
+
index
[
1
]
# Only sort if necessary...
if
(
idx
<
torch
.
cat
([
idx
.
new_zeros
(
1
),
idx
[:
-
1
]],
dim
=
0
)).
any
():
perm
=
idx
.
argsort
()
index
=
index
[:,
perm
]
value
=
None
if
value
is
None
else
value
[
perm
]
csr2csc
=
None
csc2csr
=
None
self
.
_index
=
index
self
.
_row
=
row
self
.
_rowptr
=
rowptr
self
.
_col
=
col
self
.
_value
=
value
self
.
_sparse_size
=
sparse_size
self
.
_rowcount
=
rowcount
self
.
_rowptr
=
rowptr
self
.
_colcount
=
colcount
self
.
_colptr
=
colptr
self
.
_colcount
=
colcount
self
.
_csr2csc
=
csr2csc
self
.
_csc2csr
=
csc2csr
@
property
def
index
(
self
):
return
self
.
_index
if
not
is_sorted
:
idx
=
self
.
col
.
new_zeros
(
col
.
numel
()
+
1
)
idx
[
1
:]
=
sparse_size
[
1
]
*
self
.
row
+
self
.
col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
.
argsort
()
self
.
_row
=
self
.
row
[
perm
]
self
.
_col
=
self
.
col
[
perm
]
self
.
_value
=
self
.
value
[
perm
]
if
self
.
has_value
()
else
None
self
.
_csr2csc
=
None
self
.
_csc2csr
=
None
def
has_row
(
self
):
return
self
.
_row
is
not
None
@
property
def
row
(
self
):
return
self
.
_index
[
0
]
if
self
.
_row
is
None
:
# TODO
pass
return
self
.
_row
def
has_rowptr
(
self
):
return
self
.
_rowptr
is
not
None
@
property
def
rowptr
(
self
):
if
self
.
_rowptr
is
None
:
func
=
rowptr_cuda
if
self
.
row
.
is_cuda
else
rowptr_cpu
self
.
_rowptr
=
func
.
rowptr
(
self
.
row
,
self
.
sparse_size
[
0
])
return
self
.
_rowptr
@
property
def
col
(
self
):
return
self
.
_
index
[
1
]
return
self
.
_
col
def
has_value
(
self
):
return
self
.
_value
is
not
None
...
...
@@ -159,99 +184,99 @@ class SparseStorage(object):
def
value
(
self
):
return
self
.
_value
def
set_value_
(
self
,
value
,
layout
=
None
):
def
set_value_
(
self
,
value
,
dtype
=
None
,
layout
=
None
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
value
=
torch
.
full
((
self
.
nnz
(),
),
device
=
self
.
index
.
device
)
value
=
torch
.
full
((
self
.
nnz
(),
),
dtype
=
dtype
,
device
=
self
.
col
.
device
)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
value
=
value
[
self
.
csc2csr
]
if
torch
.
is_tensor
(
value
):
assert
value
.
device
==
self
.
index
.
device
assert
value
.
size
(
0
)
==
self
.
index
.
size
(
1
)
value
=
value
if
dtype
is
None
else
value
.
to
(
dtype
)
assert
value
.
device
==
self
.
col
.
device
assert
value
.
size
(
0
)
==
self
.
col
.
numel
()
self
.
_value
=
value
return
self
def
set_value
(
self
,
value
,
layout
=
None
):
def
set_value
(
self
,
value
,
dtype
=
None
,
layout
=
None
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
value
=
torch
.
full
((
self
.
nnz
(),
),
device
=
self
.
index
.
device
)
value
=
torch
.
full
((
self
.
nnz
(),
),
dtype
=
dtype
,
device
=
self
.
col
.
device
)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
value
=
value
[
self
.
csc2csr
]
if
torch
.
is_tensor
(
value
):
assert
value
.
device
==
self
.
_index
.
device
assert
value
.
size
(
0
)
==
self
.
_index
.
size
(
1
)
return
self
.
__class__
(
self
.
_index
,
value
,
self
.
_sparse_size
,
self
.
_rowcount
,
self
.
_rowptr
,
self
.
_colcount
,
self
.
_colptr
,
self
.
_csr2csc
,
self
.
_csc2csr
,
is_sorted
=
True
,
)
value
=
value
if
dtype
is
None
else
value
.
to
(
dtype
)
assert
value
.
device
==
self
.
col
.
device
assert
value
.
size
(
0
)
==
self
.
col
.
numel
()
def
sparse_size
(
self
,
dim
=
None
):
return
self
.
_sparse_size
if
dim
is
None
else
self
.
_sparse_size
[
dim
]
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
col
,
value
=
value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
@
property
def
sparse_size
(
self
):
return
self
.
_sparse_size
def
sparse_resize
(
self
,
*
sizes
):
assert
len
(
sizes
)
==
2
old_sizes
,
nnz
=
self
.
sparse_size
(),
self
.
nnz
()
old_sparse_size
,
nnz
=
self
.
sparse_size
,
self
.
nnz
()
diff_0
=
sizes
[
0
]
-
old_size
s
[
0
]
diff_0
=
sizes
[
0
]
-
old_
sparse_
size
[
0
]
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
if
diff_0
>
0
:
if
self
.
has_rowcount
():
rowcount
=
torch
.
cat
([
rowcount
,
rowcount
.
new_zeros
(
diff_0
)])
if
self
.
has_rowptr
():
if
rowptr
is
not
None
:
rowptr
=
torch
.
cat
([
rowptr
,
rowptr
.
new_full
((
diff_0
,
),
nnz
)])
if
rowcount
is
not
None
:
rowcount
=
torch
.
cat
([
rowcount
,
rowcount
.
new_zeros
(
diff_0
)])
else
:
if
self
.
has_rowcount
():
rowcount
=
rowcount
[:
-
diff_0
]
if
self
.
has_rowptr
():
if
rowptr
is
not
None
:
rowptr
=
rowptr
[:
-
diff_0
]
if
rowcount
is
not
None
:
rowcount
=
rowcount
[:
-
diff_0
]
diff_1
=
sizes
[
1
]
-
old_size
s
[
1
]
diff_1
=
sizes
[
1
]
-
old_
sparse_
size
[
1
]
colcount
,
colptr
=
self
.
_colcount
,
self
.
_colptr
if
diff_1
>
0
:
if
self
.
has_colcount
():
colcount
=
torch
.
cat
([
colcount
,
colcount
.
new_zeros
(
diff_1
)])
if
self
.
has_colptr
():
if
colptr
is
not
None
:
colptr
=
torch
.
cat
([
colptr
,
colptr
.
new_full
((
diff_1
,
),
nnz
)])
if
colcount
is
not
None
:
colcount
=
torch
.
cat
([
colcount
,
colcount
.
new_zeros
(
diff_1
)])
else
:
if
self
.
has_colcount
():
colcount
=
colcount
[:
-
diff_1
]
if
self
.
has_colptr
():
if
colptr
is
not
None
:
colptr
=
colptr
[:
-
diff_1
]
if
colcount
is
not
None
:
colcount
=
colcount
[:
-
diff_1
]
return
self
.
__class__
(
self
.
_index
,
self
.
_value
,
sizes
,
rowcount
=
rowcount
,
rowptr
=
rowptr
,
colcount
=
colcount
,
colptr
=
colptr
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
)
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
col
,
value
=
self
.
value
,
sparse_size
=
sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
has_rowcount
(
self
):
return
self
.
_rowcount
is
not
None
@
cached_property
def
rowcount
(
self
):
rowptr
=
self
.
rowptr
return
rowptr
[
1
:]
-
rowptr
[:
-
1
]
return
self
.
rowptr
[
1
:]
-
self
.
rowptr
[:
-
1
]
def
has_
row
ptr
(
self
):
return
self
.
_
row
ptr
is
not
None
def
has_
col
ptr
(
self
):
return
self
.
_
col
ptr
is
not
None
@
cached_property
def
rowptr
(
self
):
func
=
rowptr_cuda
if
self
.
index
.
is_cuda
else
rowptr_cpu
return
func
.
rowptr
(
self
.
row
,
self
.
sparse_size
(
0
))
def
colptr
(
self
):
if
self
.
has_csr2csc
():
func
=
rowptr_cuda
if
self
.
col
.
is_cuda
else
rowptr_cpu
return
func
.
rowptr
(
self
.
col
[
self
.
csr2csc
],
self
.
sparse_size
[
1
])
else
:
colptr
=
self
.
col
.
new_zeros
(
self
.
sparse_size
[
1
]
+
1
)
torch
.
cumsum
(
self
.
colcount
,
dim
=
0
,
out
=
colptr
[
1
:])
return
colptr
def
has_colcount
(
self
):
return
self
.
_colcount
is
not
None
...
...
@@ -259,32 +284,17 @@ class SparseStorage(object):
@
cached_property
def
colcount
(
self
):
if
self
.
has_colptr
():
colptr
=
self
.
colptr
return
colptr
[
1
:]
-
colptr
[:
-
1
]
else
:
col
,
dim_size
=
self
.
col
,
self
.
sparse_size
(
1
)
return
scatter_add
(
torch
.
ones_like
(
col
),
col
,
dim_size
=
dim_size
)
def
has_colptr
(
self
):
return
self
.
_colptr
is
not
None
@
cached_property
def
colptr
(
self
):
if
self
.
has_csr2csc
():
func
=
rowptr_cuda
if
self
.
index
.
is_cuda
else
rowptr_cpu
return
func
.
rowptr
(
self
.
col
[
self
.
csr2csc
],
self
.
sparse_size
(
1
))
return
self
.
colptr
[
1
:]
-
self
.
colptr
[:
-
1
]
else
:
colcount
=
self
.
colcount
colptr
=
colcount
.
new_zeros
(
colcount
.
size
(
0
)
+
1
)
torch
.
cumsum
(
colcount
,
dim
=
0
,
out
=
colptr
[
1
:])
return
colptr
return
scatter_add
(
torch
.
ones_like
(
self
.
col
),
self
.
col
,
dim_size
=
self
.
sparse_size
[
1
])
def
has_csr2csc
(
self
):
return
self
.
_csr2csc
is
not
None
@
cached_property
def
csr2csc
(
self
):
idx
=
self
.
_
sparse_size
[
0
]
*
self
.
col
+
self
.
row
idx
=
self
.
sparse_size
[
0
]
*
self
.
col
+
self
.
row
return
idx
.
argsort
()
def
has_csc2csr
(
self
):
...
...
@@ -295,26 +305,29 @@ class SparseStorage(object):
return
self
.
csr2csc
.
argsort
()
def
is_coalesced
(
self
):
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
+
self
.
col
mask
=
idx
>
torch
.
cat
([
idx
.
new_full
((
1
,
),
-
1
),
idx
[:
-
1
]],
dim
=
0
)
return
mask
.
all
().
item
()
idx
=
self
.
col
.
new_zeros
(
self
.
col
.
numel
()
+
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
return
(
idx
[
1
:]
>
idx
[:
-
1
])
.
all
().
item
()
def
coalesce
(
self
,
reduce
=
'add'
):
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
+
self
.
col
mask
=
idx
>
torch
.
cat
([
idx
.
new_full
((
1
,
),
-
1
),
idx
[:
-
1
]],
dim
=
0
)
idx
=
self
.
col
.
new_zeros
(
self
.
col
.
numel
()
+
1
)
idx
[
1
:]
=
self
.
sparse_size
[
1
]
*
self
.
row
+
self
.
col
mask
=
idx
[
1
:]
>
idx
[:
-
1
]
if
mask
.
all
():
# Skip if indices are already coalesced.
return
self
index
=
self
.
index
[:,
mask
]
row
=
self
.
row
[
mask
]
col
=
self
.
col
[
mask
]
value
=
self
.
value
if
self
.
has_value
():
idx
=
mask
.
cumsum
(
0
)
-
1
idx
=
mask
.
cumsum
(
0
)
.
sub_
(
1
)
value
=
segment_csr
(
idx
,
value
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
self
.
__class__
(
index
,
value
,
self
.
sparse_size
(),
is_sorted
=
True
)
return
self
.
__class__
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_size
=
self
.
sparse_size
,
is_sorted
=
True
)
def
cached_keys
(
self
):
return
[
...
...
@@ -323,7 +336,7 @@ class SparseStorage(object):
]
def
fill_cache_
(
self
,
*
args
):
for
arg
in
args
or
self
.
cache_keys
:
for
arg
in
args
or
self
.
cache_keys
+
[
'row'
,
'rowptr'
]
:
getattr
(
self
,
arg
)
return
self
...
...
@@ -344,46 +357,48 @@ class SparseStorage(object):
return
new_storage
def
apply_value_
(
self
,
func
):
self
.
_value
=
optional
(
func
,
self
.
_
value
)
self
.
_value
=
optional
(
func
,
self
.
value
)
return
self
def
apply_value
(
self
,
func
):
return
self
.
__class__
(
self
.
_index
,
optional
(
func
,
self
.
_value
),
self
.
_sparse_size
,
self
.
_rowcount
,
self
.
_rowptr
,
self
.
_colcount
,
self
.
_colptr
,
self
.
_csr2csc
,
self
.
_csc2csr
,
is_sorted
=
True
,
)
return
self
.
__class__
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
col
,
value
=
optional
(
func
,
self
.
value
),
sparse_size
=
self
.
sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
apply_
(
self
,
func
):
self
.
_index
=
func
(
self
.
_index
)
self
.
_value
=
optional
(
func
,
self
.
_value
)
self
.
_row
=
optional
(
func
,
self
.
_row
)
self
.
_rowptr
=
optional
(
func
,
self
.
_rowptr
)
self
.
_col
=
func
(
self
.
col
)
self
.
_value
=
optional
(
func
,
self
.
value
)
for
key
in
self
.
cached_keys
():
setattr
(
self
,
f
'_
{
key
}
'
,
func
(
getattr
(
self
,
f
'_
{
key
}
'
)))
return
self
def
apply
(
self
,
func
):
return
self
.
__class__
(
func
(
self
.
_index
),
optional
(
func
,
self
.
_value
),
self
.
_sparse_size
,
optional
(
func
,
self
.
_rowcount
),
optional
(
func
,
self
.
_rowptr
),
optional
(
func
,
self
.
_colcount
),
optional
(
func
,
self
.
_colptr
),
optional
(
func
,
self
.
_csr2csc
),
optional
(
func
,
self
.
_csc2csr
),
row
=
optional
(
func
,
self
.
_row
),
rowptr
=
optional
(
func
,
self
.
_rowptr
),
col
=
func
(
self
.
col
),
value
=
optional
(
func
,
self
.
value
),
sparse_size
=
self
.
sparse_size
,
rowcount
=
optional
(
func
,
self
.
_rowcount
),
colptr
=
optional
(
func
,
self
.
_colptr
),
colcount
=
optional
(
func
,
self
.
_colcount
),
csr2csc
=
optional
(
func
,
self
.
_csr2csc
),
csc2csr
=
optional
(
func
,
self
.
_csc2csr
),
is_sorted
=
True
,
)
def
map
(
self
,
func
):
data
=
[
func
(
self
.
index
)]
data
=
[]
if
self
.
has_row
():
data
+=
[
func
(
self
.
row
)]
if
self
.
has_rowptr
():
data
+=
[
func
(
self
.
rowptr
)]
data
+=
[
func
(
self
.
col
)]
if
self
.
has_value
():
data
+=
[
func
(
self
.
value
)]
data
+=
[
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cached_keys
()]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment