Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
7da1c4c1
Commit
7da1c4c1
authored
Aug 05, 2018
by
rusty1s
Browse files
restructure
parent
52dcc2e5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
48 additions
and
40 deletions
+48
-40
cuda/matmul_cuda.cu
cuda/matmul_cuda.cu
+1
-4
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-3
torch_sparse/coalesce.py
torch_sparse/coalesce.py
+1
-2
torch_sparse/matmul.py
torch_sparse/matmul.py
+33
-31
torch_sparse/transpose.py
torch_sparse/transpose.py
+11
-0
No files found.
cuda/matmul_cuda.cu
View file @
7da1c4c1
...
...
@@ -30,12 +30,9 @@ static void init_cusparse() {
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spspmm_cuda
(
at
::
Tensor
A
,
at
::
Tensor
B
)
{
init_cusparse
();
A
=
A
.
coalesce
();
B
=
B
.
coalesce
();
auto
m
=
A
.
size
(
0
);
auto
n
=
B
.
size
(
1
);
auto
k
=
A
.
size
(
1
);
auto
n
=
B
.
size
(
1
);
auto
nnzA
=
A
.
_nnz
();
auto
nnzB
=
B
.
_nnz
();
...
...
torch_sparse/__init__.py
View file @
7da1c4c1
from
.coalesce
import
coalesce
from
.
spar
se
import
sparse_coo_tensor
,
to_valu
e
from
.
transpo
se
import
transpos
e
from
.matmul
import
spspmm
__all__
=
[
'coalesce'
,
'sparse_coo_tensor'
,
'to_value'
,
'transpose'
,
'spspmm'
,
]
torch_sparse/coalesce.py
View file @
7da1c4c1
...
...
@@ -2,8 +2,7 @@ import torch
import
torch_scatter
def
coalesce
(
index
,
value
,
size
,
op
=
'add'
,
fill_value
=
0
):
m
,
n
=
size
def
coalesce
(
index
,
value
,
m
,
n
,
op
=
'add'
,
fill_value
=
0
):
row
,
col
=
index
unique
,
inv
=
torch
.
unique
(
row
*
n
+
col
,
sorted
=
True
,
return_inverse
=
True
)
...
...
torch_sparse/matmul.py
View file @
7da1c4c1
import
torch
import
scipy.sparse
from
torch_sparse
import
transpose
if
torch
.
cuda
.
is_available
():
import
matmul_cuda
...
...
@@ -9,53 +10,54 @@ class SpSpMM(torch.autograd.Function):
"""Sparse matrix product of two sparse tensors with autograd support."""
@
staticmethod
def
forward
(
ctx
,
A
,
B
):
ctx
.
save_for_backward
(
A
,
B
)
return
mm
(
A
,
B
)
def
forward
(
ctx
,
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
):
indexC
,
valueC
=
mm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
ctx
.
m
,
ctx
.
k
,
ctx
.
n
=
m
,
k
,
n
ctx
.
save_for_backward
(
indexA
,
valueA
,
indexB
,
valueB
,
indexC
)
return
indexC
,
valueC
@
staticmethod
def
backward
(
ctx
,
grad_C
):
A
,
B
=
ctx
.
saved_variables
grad_A
=
grad_B
=
None
def
backward
(
ctx
,
grad_
indexC
,
grad_value
C
):
m
,
k
,
n
=
ctx
.
m
,
ctx
.
k
,
ctx
.
n
indexA
,
valueA
,
indexB
,
valueB
,
indexC
=
ctx
.
saved_variables
if
ctx
.
needs_input_grad
[
0
]:
grad_A
=
mm
(
grad_C
,
B
.
t
().
coalesce
())
grad_valueA
=
grad_valueB
=
None
if
ctx
.
needs_input_grad
[
1
]:
grad_B
=
mm
(
A
.
t
(),
grad_C
)
indexB
,
valueB
=
transpose
(
indexB
,
valueB
,
k
,
n
)
_
,
grad_valueA
=
mm
(
indexC
,
grad_valueC
,
indexB
,
valueB
,
m
,
n
,
k
)
# TODO: Filter values.
return
grad_A
,
grad_B
if
ctx
.
needs_input_grad
[
4
]:
indexA
,
valueA
=
transpose
(
indexA
,
valueA
,
m
,
k
)
_
,
grad_valueB
=
mm
(
indexA
,
valueA
,
indexC
,
grad_valueC
,
k
,
m
,
n
)
# TODO: Filter values.
return
None
,
grad_valueA
,
None
,
grad_valueB
,
None
,
None
,
None
spspmm
=
SpSpMM
.
apply
spspmm
=
SpSpMM
.
apply
def
mm
(
A
,
B
):
assert
A
.
dtype
==
B
.
dtype
assert
A
.
size
(
1
)
==
B
.
size
(
0
)
return
mm_cuda
(
A
,
B
)
if
A
.
is_cuda
else
mm_cpu
(
A
,
B
)
def
mm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
):
assert
valueA
.
dtype
==
valueB
.
dtype
def
mm_cuda
(
A
,
B
):
index
,
value
=
matmul_cuda
.
spspmm
(
A
,
B
)
size
=
torch
.
Size
([
A
.
size
(
0
),
B
.
size
(
1
)])
return
torch
.
sparse_coo_tensor
(
index
,
value
,
size
,
device
=
value
.
device
)
if
indexA
.
is_cuda
:
return
matmul_cuda
.
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
A
=
to_scipy
(
indexA
,
valueA
,
m
,
k
)
B
=
to_scipy
(
indexB
,
valueB
,
k
,
n
)
indexC
,
valueC
=
from_scipy
(
A
.
tocsr
().
dot
(
B
.
tocsr
()).
tocoo
())
def
mm_cpu
(
A
,
B
):
return
from_scipy
(
to_scipy
(
A
).
dot
(
to_scipy
(
B
)))
return
indexC
,
valueC
def
to_scipy
(
A
):
(
row
,
col
),
data
,
shape
=
A
.
_indices
(),
A
.
_values
(),
tuple
(
A
.
size
())
row
,
col
,
data
=
row
.
detach
(),
col
.
detach
(),
data
.
detach
()
return
scipy
.
sparse
.
coo_matrix
((
data
,
(
row
,
col
)),
shape
).
tocsr
()
def
to_scipy
(
index
,
value
,
m
,
n
):
(
row
,
col
),
data
=
index
.
detach
(),
value
.
detach
()
return
scipy
.
sparse
.
coo_matrix
((
data
,
(
row
,
col
)),
(
m
,
n
))
def
from_scipy
(
A
):
A
=
A
.
tocoo
()
row
,
col
,
value
,
size
=
A
.
row
,
A
.
col
,
A
.
data
,
torch
.
Size
(
A
.
shape
)
row
,
col
=
torch
.
from_numpy
(
row
).
long
(),
torch
.
from_numpy
(
col
).
long
()
value
=
torch
.
from_numpy
(
value
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
torch
.
sparse_coo_tensor
(
index
,
value
,
size
)
row
,
col
,
value
=
A
.
row
,
A
.
col
,
A
.
data
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
).
to
(
torch
.
long
)
return
index
,
value
torch_sparse/transpose.py
0 → 100644
View file @
7da1c4c1
import
torch
from
torch_sparse
import
coalesce
def
transpose
(
index
,
value
,
m
,
n
):
row
,
col
=
index
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)
index
,
value
=
coalesce
(
index
,
value
,
m
,
n
)
return
index
,
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment