Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
e61e3d45
Commit
e61e3d45
authored
Dec 20, 2019
by
rusty1s
Browse files
coalesce
parent
7517c965
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
52 deletions
+63
-52
torch_sparse/storage.py
torch_sparse/storage.py
+24
-15
torch_sparse/tensor.py
torch_sparse/tensor.py
+39
-37
No files found.
torch_sparse/storage.py
View file @
e61e3d45
import
warnings
import
torch
import
torch_scatter
from
torch_scatter
import
scatter_add
,
segment_add
...
...
@@ -37,17 +38,9 @@ class SparseStorage(object):
'rowcount'
,
'rowptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
]
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
rowcount
=
None
,
rowptr
=
None
,
colcount
=
None
,
colptr
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
rowcount
=
None
,
rowptr
=
None
,
colcount
=
None
,
colptr
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
assert
index
.
dtype
==
torch
.
long
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
...
...
@@ -185,11 +178,27 @@ class SparseStorage(object):
def
is_coalesced
(
self
):
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
+
self
.
col
mask
=
idx
==
torch
.
cat
([
idx
.
new_full
((
1
,
),
-
1
),
idx
[:
-
1
]],
dim
=
0
)
return
not
mask
.
a
ny
().
item
()
mask
=
idx
>
torch
.
cat
([
idx
.
new_full
((
1
,
),
-
1
),
idx
[:
-
1
]],
dim
=
0
)
return
mask
.
a
ll
().
item
()
def
coalesce
(
self
):
raise
NotImplementedError
def
coalesce
(
self
,
reduce
=
'add'
):
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
+
self
.
col
mask
=
idx
>
torch
.
cat
([
idx
.
new_full
((
1
,
),
-
1
),
idx
[:
-
1
]],
dim
=
0
)
if
mask
.
all
():
# Already coalesced
return
self
index
=
self
.
index
[:,
mask
]
value
=
self
.
value
if
self
.
has_value
():
assert
reduce
in
[
'add'
,
'mean'
,
'min'
,
'max'
]
idx
=
mask
.
cumsum
(
0
)
-
1
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
reduce
}
'
)
value
=
op
(
value
,
idx
,
dim
=
0
,
dim_size
=
idx
[
-
1
].
item
()
+
1
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
self
.
__class__
(
index
,
value
,
self
.
sparse_size
(),
is_sorted
=
True
)
def
cached_keys
(
self
):
return
[
...
...
torch_sparse/tensor.py
View file @
e61e3d45
...
...
@@ -14,8 +14,8 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
class
SparseTensor
(
object
):
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
is_sorted
=
False
):
self
.
storage
=
SparseStorage
(
index
,
value
,
sparse_size
,
is_sorted
=
is_sorted
)
self
.
storage
=
SparseStorage
(
index
,
value
,
sparse_size
,
is_sorted
=
is_sorted
)
@
classmethod
def
from_storage
(
self
,
storage
):
...
...
@@ -36,8 +36,8 @@ class SparseTensor(object):
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
,
is_sorted
=
False
):
return
SparseTensor
(
mat
.
_indices
(),
mat
.
_values
(),
mat
.
size
()[:
2
],
is_sorted
=
is_sorted
)
return
SparseTensor
(
mat
.
_indices
(),
mat
.
_values
(),
mat
.
size
()[:
2
],
is_sorted
=
is_sorted
)
@
classmethod
def
from_scipy
(
self
,
mat
):
...
...
@@ -54,8 +54,8 @@ class SparseTensor(object):
value
=
torch
.
from_numpy
(
mat
.
data
)
size
=
mat
.
shape
storage
=
SparseStorage
(
index
,
value
,
size
,
rowptr
=
rowptr
,
colptr
=
colptr
,
is_sorted
=
True
)
storage
=
SparseStorage
(
index
,
value
,
size
,
rowptr
=
rowptr
,
colptr
=
colptr
,
is_sorted
=
True
)
return
SparseTensor
.
from_storage
(
storage
)
...
...
@@ -105,8 +105,8 @@ class SparseTensor(object):
def
is_coalesced
(
self
):
return
self
.
storage
.
is_coalesced
()
def
coalesce
(
self
):
return
self
.
from_storage
(
self
.
storage
.
coalesce
())
def
coalesce
(
self
,
reduce
=
'add'
):
return
self
.
from_storage
(
self
.
storage
.
coalesce
(
reduce
))
def
cached_keys
(
self
):
return
self
.
storage
.
cached_keys
()
...
...
@@ -192,8 +192,8 @@ class SparseTensor(object):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
cpu
()))
def
cuda
(
self
,
device
=
None
,
non_blocking
=
False
,
**
kwargs
):
storage
=
self
.
storage
.
apply
(
lambda
x
:
x
.
cuda
(
device
,
non_blocking
,
**
kwargs
))
storage
=
self
.
storage
.
apply
(
lambda
x
:
x
.
cuda
(
device
,
non_blocking
,
**
kwargs
))
return
self
.
from_storage
(
storage
)
@
property
...
...
@@ -215,8 +215,8 @@ class SparseTensor(object):
if
dtype
==
self
.
dtype
:
return
self
storage
=
self
.
storage
.
apply_value
(
lambda
x
:
x
.
type
(
dtype
,
non_blocking
,
**
kwargs
))
storage
=
self
.
storage
.
apply_value
(
lambda
x
:
x
.
type
(
dtype
,
non_blocking
,
**
kwargs
))
return
self
.
from_storage
(
storage
)
...
...
@@ -285,12 +285,9 @@ class SparseTensor(object):
def
to_torch_sparse_coo_tensor
(
self
,
dtype
=
None
,
requires_grad
=
False
):
index
,
value
=
self
.
coo
()
return
torch
.
sparse_coo_tensor
(
index
,
value
if
self
.
has_value
()
else
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
,
device
=
self
.
device
),
self
.
size
(),
device
=
self
.
device
,
requires_grad
=
requires_grad
)
index
,
value
if
self
.
has_value
()
else
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
,
device
=
self
.
device
),
self
.
size
(),
device
=
self
.
device
,
requires_grad
=
requires_grad
)
def
to_scipy
(
self
,
dtype
=
None
,
layout
=
None
):
assert
self
.
dim
()
==
2
...
...
@@ -392,11 +389,6 @@ SparseTensor.index_select_nnz = index_select_nnz
SparseTensor
.
masked_select
=
masked_select
SparseTensor
.
masked_select_nnz
=
masked_select_nnz
# def __getitem__(self, idx):
# # Convert int and slice to index tensor
# # Filter list into edge and sparse slice
# raise NotImplementedError
# def remove_diag(self):
# raise NotImplementedError
...
...
@@ -503,20 +495,30 @@ if __name__ == '__main__':
value
=
torch
.
randn
(
data
.
num_edges
,
10
)
mat
=
SparseTensor
(
data
.
edge_index
,
value
)
index
=
torch
.
tensor
([
0
,
1
,
2
])
mask
=
torch
.
zeros
(
data
.
num_nodes
,
dtype
=
torch
.
bool
)
mask
[:
3
]
=
True
print
(
mat
[
1
].
size
())
print
(
mat
[
1
,
1
].
size
())
print
(
mat
[...,
-
1
].
size
())
print
(
mat
[:
10
,
...,
-
1
].
size
())
print
(
mat
[:,
-
1
].
size
())
print
(
mat
[
1
,
:,
-
1
].
size
())
print
(
mat
[
1
:
4
,
1
:
4
].
size
())
print
(
mat
[
index
].
size
())
print
(
mat
[
index
,
index
].
size
())
print
(
mat
[
mask
,
index
].
size
())
index
=
torch
.
tensor
([
[
0
,
1
,
1
,
2
,
2
],
[
1
,
2
,
2
,
2
,
3
],
])
value
=
torch
.
tensor
([
1
,
2
,
3
,
4
,
5
])
mat
=
SparseTensor
(
index
,
value
)
print
(
mat
)
print
(
mat
.
coalesce
())
# index = torch.tensor([0, 1, 2])
# mask = torch.zeros(data.num_nodes, dtype=torch.bool)
# mask[:3] = True
# print(mat[1].size())
# print(mat[1, 1].size())
# print(mat[..., -1].size())
# print(mat[:10, ..., -1].size())
# print(mat[:, -1].size())
# print(mat[1, :, -1].size())
# print(mat[1:4, 1:4].size())
# print(mat[index].size())
# print(mat[index, index].size())
# print(mat[mask, index].size())
# mat[::-1]
# mat[::2]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment