OpenDAS / torch-sparse · Commits

Commit a1c268a5
authored Jan 26, 2020 by rusty1s

fix matmul

parent b6a1f005
Showing 4 changed files with 15 additions and 56 deletions (+15 −56):
  test/test_convert.py    +0  −37
  test/test_matmul.py     +1  −1
  torch_sparse/cat.py     +3  −3
  torch_sparse/matmul.py  +11 −15
test/test_convert.py

```diff
-import time
 import torch
 from torch_sparse import to_scipy, from_scipy
 from torch_sparse import to_torch_sparse, from_torch_sparse
-from torch_sparse.storage import SparseStorage
-from scipy.io import loadmat


 def test_convert_scipy():
     ...

@@ -24,37 +21,3 @@ def test_convert_torch_sparse():
     out = from_torch_sparse(to_torch_sparse(index, value, N, N).coalesce())
     assert out[0].tolist() == index.tolist()
     assert out[1].tolist() == value.tolist()
-
-
-def test_ind2ptr():
-    name = ('DIMACS10', 'citationCiteseer')[1]
-    mat = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2]
-    mat = mat.tocsr().tocoo()
-
-    mat = mat.tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(torch.long).cuda()
-    mat = mat.tocoo()
-    row = torch.from_numpy(mat.row).to(torch.long).cuda()
-    col = torch.from_numpy(mat.col).to(torch.long).cuda()
-
-    storage = SparseStorage(row=row, col=col)
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        storage.rowptr
-        storage._rowptr = None
-    torch.cuda.synchronize()
-    print(time.perf_counter() - t)
-    assert storage.rowptr.tolist() == rowptr.tolist()
-
-    storage = SparseStorage(rowptr=rowptr, col=col)
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        storage.row
-        storage._row = None
-    torch.cuda.synchronize()
-    print(time.perf_counter() - t)
-    assert storage.row.tolist() == row.tolist()
```
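The deleted `test_ind2ptr` was a CUDA micro-benchmark of `SparseStorage`'s cached conversion between COO row indices and a CSR row pointer, likely dropped because it depends on a local `benchmark/*.mat` file and a GPU. For reference, a minimal CPU sketch of the two conversions it exercised; the helper names `ind2ptr`/`ptr2ind` are illustrative here, not the library's public API:

```python
import torch

def ind2ptr(row: torch.Tensor, num_rows: int) -> torch.Tensor:
    # Count nonzeros per row, then prefix-sum into a CSR row pointer of
    # length num_rows + 1 (rowptr[0] == 0, rowptr[-1] == nnz).
    counts = torch.bincount(row, minlength=num_rows)
    return torch.cat([counts.new_zeros(1), counts.cumsum(0)])

def ptr2ind(rowptr: torch.Tensor) -> torch.Tensor:
    # Repeat each row index once per nonzero in that row.
    counts = rowptr[1:] - rowptr[:-1]
    return torch.repeat_interleave(torch.arange(rowptr.numel() - 1), counts)

row = torch.tensor([0, 0, 1, 3])   # sorted COO row indices, 4 rows
rowptr = ind2ptr(row, num_rows=4)  # tensor([0, 2, 3, 3, 4])
assert ptr2ind(rowptr).tolist() == row.tolist()
```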
test/test_matmul.py

```diff
@@ -19,7 +19,7 @@ def test_spmm(dtype, device, reduce):
     src[2:4, :] = 0  # Remove multiple rows.
     src[:, 2:4] = 0  # Remove multiple columns.
     src = SparseTensor.from_dense(src).requires_grad_()
-    (row, col), value = src.coo()
+    row, col, value = src.coo()

     other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                         requires_grad=True)
```
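The one-line change tracks `SparseTensor.coo()` now returning a flat `(row, col, value)` triple rather than a nested `((row, col), value)` pair; the same unpacking fix recurs in `torch_sparse/cat.py` below. A usage sketch, assuming `torch_sparse` is installed:

```python
import torch
from torch_sparse import SparseTensor

src = SparseTensor.from_dense(torch.randn(5, 8))
row, col, value = src.coo()  # flat triple, matching the updated call sites
print(row.size(), col.size(), value.size())
```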
torch_sparse/cat.py

```diff
@@ -22,7 +22,7 @@ def cat(tensors, dim):
     if dim == 0:
         for tensor in tensors:
-            (row, col), value = tensor.coo()
+            row, col, value = tensor.coo()
             rows += [row + sparse_size[0]]
             cols += [col]
             values += [value]

@@ -48,7 +48,7 @@ def cat(tensors, dim):
     elif dim == 1:
         for tensor in tensors:
-            (row, col), value = tensor.coo()
+            row, col, value = tensor.coo()
             rows += [row]
             cols += [col + sparse_size[1]]
             values += [value]

@@ -76,7 +76,7 @@ def cat(tensors, dim):
     elif dim == (0, 1) or dim == (1, 0):
         for tensor in tensors:
-            (row, col), value = tensor.coo()
+            row, col, value = tensor.coo()
             rows += [row + sparse_size[0]]
             cols += [col + sparse_size[1]]
             values += [value] if has_value else []
```
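All three hunks apply the same unpacking fix at the three concatenation modes: `dim=0` shifts row indices by the rows accumulated so far, `dim=1` shifts column indices, and `dim=(0, 1)` shifts both, producing a block-diagonal stack. A self-contained sketch of the diagonal case on raw COO triples; the `cat_diag` helper is illustrative, not the library API:

```python
import torch

def cat_diag(coos):
    # Block-diagonally stack a list of (row, col, value, (m, n)) COO matrices.
    rows, cols, values = [], [], []
    M, N = 0, 0  # running output size, mirrors `sparse_size` in cat.py
    for row, col, value, (m, n) in coos:
        rows += [row + M]  # shift rows by the rows accumulated so far
        cols += [col + N]  # shift cols by the cols accumulated so far
        values += [value]
        M, N = M + m, N + n
    return torch.cat(rows), torch.cat(cols), torch.cat(values), (M, N)

a = (torch.tensor([0, 1]), torch.tensor([1, 0]), torch.tensor([1., 2.]), (2, 2))
b = (torch.tensor([0]), torch.tensor([0]), torch.tensor([3.]), (1, 1))
row, col, value, size = cat_diag([a, b])
# row=[0, 1, 2], col=[1, 0, 2], value=[1., 2., 3.], size=(3, 3)
```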
torch_sparse/matmul.py

```diff
@@ -40,24 +40,20 @@ class SPMM(torch.autograd.Function):
                 arg_out) = ctx.saved_tensors

         invalid_arg_mask = arg_out_ind = None
-        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
-                                             or ctx.needs_input_grad[6]):
-            invalid_arg_mask = arg_out == row.size(0)
+        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
+                                             or ctx.needs_input_grad[4]):
+            invalid_arg_mask = arg_out == col.size(0)
             arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

         grad_value = None
         if ctx.needs_input_grad[3]:
-            if ctx.reduce in ['sum', 'add']:
-                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    row, rowptr, col, mat, grad_out, ctx.reduce)
-            if ctx.reduce == 'mean':
+            if ctx.reduce in ['sum', 'add', 'mean']:
+                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
+                    row, rowptr, col, mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:
-                col = col[arg_out_ind.flatten()].view_as(arg_out)
-                out = mat.gather(-2, col).mul_(grad_out)
+                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
+                out = mat.gather(-2, col_tmp).mul_(grad_out)
                 out.masked_fill_(invalid_arg_mask, 0)
                 grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                          dim=0, dim_size=value.numel() + 1)
```
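This hunk carries the actual "fix matmul": for `min`/`max` reductions the sentinel comparison now uses `col.size(0)` (the number of nonzeros) instead of `row.size(0)`, the `needs_input_grad` slots are corrected to 3/4, the formerly separate `'mean'` branch is folded into the `spmm_val_bw` path, and `col` is no longer clobbered (renamed to `col_tmp`) since the original `col` is still needed later in `backward`. The min/max value-gradient routes each output's upstream gradient to the single winning nonzero via `arg_out`; a self-contained toy sketch of that mechanism (names and shapes are illustrative, one feature column):

```python
import torch

# Toy backward for out[i] = max_{k in row i} value[k] * mat[col[k]].
# arg_out[i] holds the winning nonzero index k, with k == nnz used as a
# sentinel for empty rows (no winner).
nnz = 4
col = torch.tensor([0, 2, 1, 2])            # column of each nonzero
mat = torch.tensor([10., 20., 30.])         # dense operand, 1 feature
arg_out = torch.tensor([0, 3, nnz])         # per-row argmax; row 2 is empty
grad_out = torch.tensor([1., 2., 5.])       # upstream gradient per row

invalid = arg_out == nnz                    # mirrors `invalid_arg_mask`
arg_ind = arg_out.masked_fill(invalid, -1)  # safe index (wraps to last entry)
out = mat[col[arg_ind]] * grad_out          # d out[i] / d value[k*], scaled
out = out.masked_fill(invalid, 0)           # empty rows contribute nothing
grad_value = torch.zeros(nnz + 1).scatter_add_(0, arg_out, out)[:nnz]
# grad_value == [10., 0., 0., 60.]: only winning entries get gradient, and
# the extra slot (dim_size = nnz + 1) absorbs the sentinel writes.
```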
```diff
@@ -85,8 +81,8 @@ class SPMM(torch.autograd.Function):
             else:
                 value = grad_out

             value.masked_fill_(invalid_arg_mask, 0)
-            col = col[arg_out_ind.flatten()].view_as(arg_out)
-            grad_mat = scatter_add(value, col, dim=-2,
+            col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
+            grad_mat = scatter_add(value, col_tmp, dim=-2,
                                    dim_size=mat.size(-2))

         return None, None, None, grad_value, grad_mat, None, None, None, None
```
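The same `col` to `col_tmp` rename protects the gradient of the dense operand: each output's contribution is scattered back onto the row of `mat` selected by the winning column. A self-contained toy sketch in the same illustrative style:

```python
import torch

# Toy gradient wrt the dense operand for the max reduction above.
nnz = 4
col = torch.tensor([0, 2, 1, 2])
value = torch.tensor([.5, 1.5, 2.5, 3.5])  # nonzero values
arg_out = torch.tensor([0, 3, nnz])        # row 2 empty (sentinel nnz)
grad_out = torch.tensor([1., 2., 5.])

invalid = arg_out == nnz
arg_ind = arg_out.masked_fill(invalid, -1)
v = (value[arg_ind] * grad_out).masked_fill(invalid, 0)
col_tmp = col[arg_ind]                     # mirrors `col_tmp` in the hunk
grad_mat = torch.zeros(3).scatter_add_(0, col_tmp, v)
# grad_mat == [0.5, 0., 7.]: d out[i] / d mat[col[k*]] = value[k*].
```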
```diff
@@ -119,7 +115,7 @@ class SPSPMM(torch.autograd.Function):
         rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
         colC = torch.from_numpy(C.indices).to(torch.int64)
         valueC = torch.from_numpy(C.data)
-        valueC = valueC.to(dtype) if dtype is not None else valueC
+        valueC = valueC.to(dtype) if dtype is not None else None
         ctx.mark_non_differentiable(rowptrC, colC)
```
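The context suggests `SPSPMM.forward` computes the sparse-sparse product `C` through SciPy's CSR routines and converts it back to torch tensors; the changed line makes the output carry `value = None` when `dtype` is `None` (i.e., when the inputs had no explicit values) instead of returning SciPy's placeholder data. A standalone sketch of that round trip, with made-up inputs:

```python
import torch
import scipy.sparse

# Sparse @ sparse via SciPy CSR, converted back to torch CSR pieces,
# mirroring the tail of SPSPMM.forward.
A = scipy.sparse.random(4, 5, density=0.5, format='csr')
B = scipy.sparse.random(5, 3, density=0.5, format='csr')
C = A @ B

rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
colC = torch.from_numpy(C.indices).to(torch.int64)
valueC = torch.from_numpy(C.data)

dtype = None  # None signals "inputs carried no values" in the hunk above
valueC = valueC.to(dtype) if dtype is not None else None
```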
```diff
@@ -152,8 +148,8 @@ def matmul(src, other, reduce='sum'):
     rowptr, col, value = src.csr()

     row = None
-    if reduce in ['sum', 'add'] and (src.requires_grad
-                                     or other.reuqires_grad):
+    if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
+                                             or other.reuqires_grad):
         row = src.storage.row

     rowcount = None
```
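The final hunk mirrors the merged `'sum'/'add'/'mean'` branch in `SPMM.backward`: with `'mean'` now handled by `spmm_val_bw`, the COO `row` vector must also be materialized for `reduce='mean'` whenever gradients are needed. Note that `other.reuqires_grad` is a pre-existing typo for `other.requires_grad`, untouched by this commit; it is only evaluated when `src.requires_grad` is `False` (the `or` short-circuits otherwise) and would then likely raise `AttributeError` on a dense tensor. A usage sketch of the path this change enables, assuming `matmul` is exported at package level:

```python
import torch
from torch_sparse import SparseTensor, matmul

src = SparseTensor.from_dense(torch.randn(4, 5)).requires_grad_()
other = torch.randn(5, 3, requires_grad=True)

out = matmul(src, other, reduce='mean')  # now takes the row-materializing path
out.sum().backward()                     # mean backward via spmm_val_bw
```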