Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
bff32a09
Unverified
Commit
bff32a09
authored
Jan 12, 2023
by
czkkkkkk
Committed by
GitHub
Jan 12, 2023
Browse files
[Sparse] Use efficient implementation for Diag @ Sparse and Sparse @ Diag. (#5147)
parent
21c4c29a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
128 additions
and
4 deletions
+128
-4
python/dgl/sparse/matmul.py
python/dgl/sparse/matmul.py
+59
-3
tests/pytorch/sparse/test_matmul.py
tests/pytorch/sparse/test_matmul.py
+69
-1
No files found.
python/dgl/sparse/matmul.py
View file @
bff32a09
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
.diag_matrix
import
diag
,
DiagMatrix
from
.diag_matrix
import
diag
,
DiagMatrix
from
.sparse_matrix
import
SparseMatrix
from
.sparse_matrix
import
SparseMatrix
,
val_like
__all__
=
[
"spmm"
,
"bspmm"
,
"spspmm"
,
"mm"
]
__all__
=
[
"spmm"
,
"bspmm"
,
"spspmm"
,
"mm"
]
...
@@ -117,6 +117,62 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
...
@@ -117,6 +117,62 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
return
diag
(
diag_val
.
to
(
A1
.
device
),
(
M
,
P
))
return
diag
(
diag_val
.
to
(
A1
.
device
),
(
M
,
P
))
def
_sparse_diag_mm
(
A
,
D
):
"""Internal function for multiplying a sparse matrix by a diagonal matrix.
Parameters
----------
A : SparseMatrix
Matrix of shape (N, M), with values of shape (nnz1)
D : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Returns
-------
SparseMatrix
SparseMatrix with shape (N, P)
"""
assert
(
A
.
shape
[
1
]
==
D
.
shape
[
0
]
),
f
"The second dimension of SparseMatrix should be equal to the first
\
dimension of DiagMatrix in matmul(SparseMatrix, DiagMatrix), but the
\
shapes of SparseMatrix and DiagMatrix are
{
A
.
shape
}
and
{
D
.
shape
}
\
respectively."
assert
(
D
.
shape
[
0
]
==
D
.
shape
[
1
]
),
f
"The DiagMatrix should be a square in matmul(SparseMatrix, DiagMatrix)
\
but got
{
D
.
shape
}
"
return
val_like
(
A
,
D
.
val
[
A
.
col
]
*
A
.
val
)
def
_diag_sparse_mm
(
D
,
A
):
"""Internal function for multiplying a diag matrix by a sparse matrix.
Parameters
----------
D : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Returns
-------
SparseMatrix
SparseMatrix with shape (N, P)
"""
assert
(
D
.
shape
[
1
]
==
A
.
shape
[
0
]
),
f
"The second dimension of DiagMatrix should be equal to the first
\
dimension of SparseMatrix in matmul(DiagMatrix, SparseMatrix), but the
\
shapes of DiagMatrix and SparseMatrix are
{
D
.
shape
}
and
{
A
.
shape
}
\
respectively."
assert
(
D
.
shape
[
0
]
==
D
.
shape
[
1
]
),
f
"The DiagMatrix should be a square in matmul(DiagMatrix, SparseMatrix)
\
but got
{
D
.
shape
}
"
return
val_like
(
A
,
D
.
val
[
A
.
row
]
*
A
.
val
)
def
spspmm
(
def
spspmm
(
A1
:
Union
[
SparseMatrix
,
DiagMatrix
],
A2
:
Union
[
SparseMatrix
,
DiagMatrix
]
A1
:
Union
[
SparseMatrix
,
DiagMatrix
],
A2
:
Union
[
SparseMatrix
,
DiagMatrix
]
)
->
Union
[
SparseMatrix
,
DiagMatrix
]:
)
->
Union
[
SparseMatrix
,
DiagMatrix
]:
...
@@ -165,9 +221,9 @@ def spspmm(
...
@@ -165,9 +221,9 @@ def spspmm(
if
isinstance
(
A1
,
DiagMatrix
)
and
isinstance
(
A2
,
DiagMatrix
):
if
isinstance
(
A1
,
DiagMatrix
)
and
isinstance
(
A2
,
DiagMatrix
):
return
_diag_diag_mm
(
A1
,
A2
)
return
_diag_diag_mm
(
A1
,
A2
)
if
isinstance
(
A1
,
DiagMatrix
):
if
isinstance
(
A1
,
DiagMatrix
):
A1
=
A1
.
as_sparse
(
)
return
_diag_sparse_mm
(
A1
,
A2
)
if
isinstance
(
A2
,
DiagMatrix
):
if
isinstance
(
A2
,
DiagMatrix
):
A2
=
A2
.
as_sparse
(
)
return
_sparse_diag_mm
(
A1
,
A2
)
return
SparseMatrix
(
return
SparseMatrix
(
torch
.
ops
.
dgl_sparse
.
spspmm
(
A1
.
c_sparse_matrix
,
A2
.
c_sparse_matrix
)
torch
.
ops
.
dgl_sparse
.
spspmm
(
A1
.
c_sparse_matrix
,
A2
.
c_sparse_matrix
)
)
)
...
...
tests/pytorch/sparse/test_matmul.py
View file @
bff32a09
...
@@ -4,7 +4,7 @@ import backend as F
...
@@ -4,7 +4,7 @@ import backend as F
import
pytest
import
pytest
import
torch
import
torch
from
dgl.sparse
import
bspmm
,
from_coo
,
val_like
from
dgl.sparse
import
bspmm
,
diag
,
from_coo
,
mm
,
val_like
from
.utils
import
(
from
.utils
import
(
clone_detach_and_grad
,
clone_detach_and_grad
,
...
@@ -144,3 +144,71 @@ def test_spspmm_duplicate():
...
@@ -144,3 +144,71 @@ def test_spspmm_duplicate():
pass
pass
else
:
else
:
assert
False
,
"Should raise error."
assert
False
,
"Should raise error."
@
pytest
.
mark
.
parametrize
(
"create_func"
,
[
rand_coo
,
rand_csr
,
rand_csc
])
@
pytest
.
mark
.
parametrize
(
"sparse_shape"
,
[(
5
,
5
),
(
5
,
6
)])
@
pytest
.
mark
.
parametrize
(
"nnz"
,
[
1
,
10
])
def
test_sparse_diag_mm
(
create_func
,
sparse_shape
,
nnz
):
dev
=
F
.
ctx
()
diag_shape
=
sparse_shape
[
1
],
sparse_shape
[
1
]
A
=
create_func
(
sparse_shape
,
nnz
,
dev
)
diag_val
=
torch
.
randn
(
sparse_shape
[
1
],
device
=
dev
,
requires_grad
=
True
)
D
=
diag
(
diag_val
,
diag_shape
)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B
=
mm
(
A
,
D
)
grad
=
torch
.
randn_like
(
B
.
val
)
B
.
val
.
backward
(
grad
)
torch_A
=
sparse_matrix_to_torch_sparse
(
A
)
torch_D
=
sparse_matrix_to_torch_sparse
(
D
.
as_sparse
())
torch_B
=
torch
.
sparse
.
mm
(
torch_A
,
torch_D
)
torch_B_grad
=
sparse_matrix_to_torch_sparse
(
B
,
grad
)
torch_B
.
backward
(
torch_B_grad
)
with
torch
.
no_grad
():
assert
torch
.
allclose
(
B
.
dense
(),
torch_B
.
to_dense
(),
atol
=
1e-05
)
assert
torch
.
allclose
(
val_like
(
A
,
A
.
val
.
grad
).
dense
(),
torch_A
.
grad
.
to_dense
(),
atol
=
1e-05
,
)
assert
torch
.
allclose
(
diag
(
D
.
val
.
grad
,
D
.
shape
).
dense
(),
torch_D
.
grad
.
to_dense
(),
atol
=
1e-05
,
)
@
pytest
.
mark
.
parametrize
(
"create_func"
,
[
rand_coo
,
rand_csr
,
rand_csc
])
@
pytest
.
mark
.
parametrize
(
"sparse_shape"
,
[(
5
,
5
),
(
5
,
6
)])
@
pytest
.
mark
.
parametrize
(
"nnz"
,
[
1
,
10
])
def
test_diag_sparse_mm
(
create_func
,
sparse_shape
,
nnz
):
dev
=
F
.
ctx
()
diag_shape
=
sparse_shape
[
0
],
sparse_shape
[
0
]
A
=
create_func
(
sparse_shape
,
nnz
,
dev
)
diag_val
=
torch
.
randn
(
sparse_shape
[
0
],
device
=
dev
,
requires_grad
=
True
)
D
=
diag
(
diag_val
,
diag_shape
)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B
=
mm
(
D
,
A
)
grad
=
torch
.
randn_like
(
B
.
val
)
B
.
val
.
backward
(
grad
)
torch_A
=
sparse_matrix_to_torch_sparse
(
A
)
torch_D
=
sparse_matrix_to_torch_sparse
(
D
.
as_sparse
())
torch_B
=
torch
.
sparse
.
mm
(
torch_D
,
torch_A
)
torch_B_grad
=
sparse_matrix_to_torch_sparse
(
B
,
grad
)
torch_B
.
backward
(
torch_B_grad
)
with
torch
.
no_grad
():
assert
torch
.
allclose
(
B
.
dense
(),
torch_B
.
to_dense
(),
atol
=
1e-05
)
assert
torch
.
allclose
(
val_like
(
A
,
A
.
val
.
grad
).
dense
(),
torch_A
.
grad
.
to_dense
(),
atol
=
1e-05
,
)
assert
torch
.
allclose
(
diag
(
D
.
val
.
grad
,
D
.
shape
).
dense
(),
torch_D
.
grad
.
to_dense
(),
atol
=
1e-05
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment