Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
9fe44a44
Commit
9fe44a44
authored
Dec 18, 2019
by
rusty1s
Browse files
test code
parent
ac5d7a78
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
9 deletions
+84
-9
torch_sparse/sparse.py
torch_sparse/sparse.py
+7
-9
torch_sparse/spmm.py
torch_sparse/spmm.py
+77
-0
No files found.
torch_sparse/sparse.py
View file @
9fe44a44
...
@@ -17,8 +17,8 @@ def __is_scalar__(x):
...
@@ -17,8 +17,8 @@ def __is_scalar__(x):
class
SparseTensor
(
object
):
class
SparseTensor
(
object
):
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
is_sorted
=
False
):
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
is_sorted
=
False
):
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
self
.
_storage
=
SparseStorage
(
self
.
_storage
=
SparseStorage
(
index
[
0
],
index
[
1
],
value
,
sparse_size
,
index
[
0
],
index
[
1
],
value
,
sparse_size
,
is_sorted
=
is_sorted
)
is_sorted
=
is_sorted
)
@
classmethod
@
classmethod
def
from_storage
(
self
,
storage
):
def
from_storage
(
self
,
storage
):
...
@@ -184,8 +184,8 @@ class SparseTensor(object):
...
@@ -184,8 +184,8 @@ class SparseTensor(object):
if
self
.
has_value
:
if
self
.
has_value
:
return
self
.
set_value
(
self
.
_value
+
other
,
'coo'
)
return
self
.
set_value
(
self
.
_value
+
other
,
'coo'
)
else
:
else
:
return
self
.
set_value
(
return
self
.
set_value
(
torch
.
full
((
self
.
nnz
(),
),
other
+
1
),
torch
.
full
((
self
.
nnz
(),
),
other
+
1
),
'coo'
)
'coo'
)
elif
torch
.
is_tensor
(
other
):
elif
torch
.
is_tensor
(
other
):
if
layout
is
None
:
if
layout
is
None
:
layout
=
'coo'
layout
=
'coo'
...
@@ -249,9 +249,7 @@ class SparseTensor(object):
...
@@ -249,9 +249,7 @@ class SparseTensor(object):
return
torch
.
sparse_coo_tensor
(
return
torch
.
sparse_coo_tensor
(
index
,
index
,
torch
.
ones_like
(
self
.
_row
,
dtype
)
if
value
is
None
else
value
,
torch
.
ones_like
(
self
.
_row
,
dtype
)
if
value
is
None
else
value
,
self
.
size
(),
self
.
size
(),
device
=
self
.
device
,
requires_grad
=
requires_grad
)
device
=
self
.
device
,
requires_grad
=
requires_grad
)
def
__repr__
(
self
):
def
__repr__
(
self
):
i
=
' '
*
6
i
=
' '
*
6
...
@@ -292,8 +290,8 @@ if __name__ == '__main__':
...
@@ -292,8 +290,8 @@ if __name__ == '__main__':
print
(
mat1
)
print
(
mat1
)
mat1
=
mat1
.
t
()
mat1
=
mat1
.
t
()
mat2
=
torch
.
sparse_coo_tensor
(
mat2
=
torch
.
sparse_coo_tensor
(
data
.
edge_index
,
torch
.
ones
(
data
.
num_edges
),
data
.
edge_index
,
torch
.
ones
(
data
.
num_edges
),
device
=
device
)
device
=
device
)
mat2
=
mat2
.
coalesce
()
mat2
=
mat2
.
coalesce
()
mat2
=
mat2
.
t
().
coalesce
()
mat2
=
mat2
.
t
().
coalesce
()
...
...
torch_sparse/spmm.py
View file @
9fe44a44
import
torch
from
torch_scatter
import
scatter_add
from
torch_scatter
import
scatter_add
from
torch_sparse.sparse
import
SparseTensor
if
torch
.
cuda
.
is_available
():
import
torch_sparse.spmm_cuda
def
spmm_
(
sparse_mat
,
mat
,
reduce
=
'add'
):
assert
reduce
in
[
'add'
,
'mean'
,
'min'
,
'max'
]
assert
sparse_mat
.
dim
()
==
2
and
mat
.
dim
()
==
2
assert
sparse_mat
.
size
(
1
)
==
mat
.
size
(
0
)
rowptr
,
col
,
value
=
sparse_mat
.
csr
()
mat
=
mat
.
contiguous
()
if
reduce
in
[
'add'
,
'mean'
]:
return
torch_sparse
.
spmm_cuda
.
spmm
(
rowptr
,
col
,
value
,
mat
,
reduce
)
else
:
return
torch_sparse
.
spmm_cuda
.
spmm_arg
(
rowptr
,
col
,
value
,
mat
,
reduce
)
def
spmm
(
index
,
value
,
m
,
n
,
matrix
):
def
spmm
(
index
,
value
,
m
,
n
,
matrix
):
"""Matrix product of sparse matrix with dense matrix.
"""Matrix product of sparse matrix with dense matrix.
...
@@ -24,3 +44,60 @@ def spmm(index, value, m, n, matrix):
...
@@ -24,3 +44,60 @@ def spmm(index, value, m, n, matrix):
out
=
scatter_add
(
out
,
row
,
dim
=
0
,
dim_size
=
m
)
out
=
scatter_add
(
out
,
row
,
dim
=
0
,
dim_size
=
m
)
return
out
return
out
if
__name__
==
'__main__'
:
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
row
=
torch
.
tensor
([
0
,
0
,
0
,
1
,
1
,
1
],
device
=
device
)
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
],
device
=
device
)
value
=
torch
.
ones_like
(
col
,
dtype
=
torch
.
float
,
device
=
device
)
value
=
None
sparse_mat
=
SparseTensor
(
torch
.
stack
([
row
,
col
],
dim
=
0
),
value
)
mat
=
torch
.
tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]],
dtype
=
torch
.
float
,
device
=
device
)
out1
=
spmm_
(
sparse_mat
,
mat
,
reduce
=
'add'
)
out2
=
sparse_mat
.
to_dense
()
@
mat
assert
torch
.
allclose
(
out1
,
out2
)
from
torch_geometric.datasets
import
Reddit
,
Planetoid
# noqa
import
time
# noqa
# Warmup
x
=
torch
.
randn
((
1000
,
1000
),
device
=
device
)
for
_
in
range
(
100
):
x
.
sum
()
# dataset = Reddit('/tmp/Reddit')
dataset
=
Planetoid
(
'/tmp/PubMed'
,
'PubMed'
)
# dataset = Planetoid('/tmp/Cora', 'Cora')
data
=
dataset
[
0
].
to
(
device
)
mat
=
torch
.
randn
((
data
.
num_nodes
,
1024
),
device
=
device
)
value
=
torch
.
ones
(
data
.
num_edges
,
device
=
device
)
sparse_mat
=
SparseTensor
(
data
.
edge_index
,
value
)
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
for
_
in
range
(
100
):
out1
=
spmm_
(
sparse_mat
,
mat
,
reduce
=
'add'
)
out1
=
out1
[
0
]
if
isinstance
(
out1
,
tuple
)
else
out1
torch
.
cuda
.
synchronize
()
print
(
'My: '
,
time
.
perf_counter
()
-
t
)
sparse_mat
=
torch
.
sparse_coo_tensor
(
data
.
edge_index
,
value
)
sparse_mat
=
sparse_mat
.
coalesce
()
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
for
_
in
range
(
100
):
out2
=
sparse_mat
@
mat
torch
.
cuda
.
synchronize
()
print
(
'Torch: '
,
time
.
perf_counter
()
-
t
)
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
for
_
in
range
(
100
):
spmm
(
data
.
edge_index
,
value
,
data
.
num_nodes
,
data
.
num_nodes
,
mat
)
torch
.
cuda
.
synchronize
()
print
(
'Scatter:'
,
time
.
perf_counter
()
-
t
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment