Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b5c5c860
Unverified
Commit
b5c5c860
authored
Jan 13, 2023
by
czkkkkkk
Committed by
GitHub
Jan 13, 2023
Browse files
[Sparse] Refactor matmul interface. (#5162)
* [Sparse] Refactor matmul interface. * Update
parent
9334421d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
135 additions
and
73 deletions
+135
-73
docs/source/api/python/dgl.sparse_v0.rst
docs/source/api/python/dgl.sparse_v0.rst
+1
-1
python/dgl/sparse/matmul.py
python/dgl/sparse/matmul.py
+125
-62
tests/pytorch/sparse/test_matmul.py
tests/pytorch/sparse/test_matmul.py
+9
-10
No files found.
docs/source/api/python/dgl.sparse_v0.rst
View file @
b5c5c860
...
@@ -171,10 +171,10 @@ Matrix Multiplication
...
@@ -171,10 +171,10 @@ Matrix Multiplication
.. autosummary::
.. autosummary::
:toctree: ../../generated/
:toctree: ../../generated/
matmul
spmm
spmm
bspmm
bspmm
spspmm
spspmm
mm
sddmm
sddmm
bsddmm
bsddmm
...
...
python/dgl/sparse/matmul.py
View file @
b5c5c860
...
@@ -8,11 +8,11 @@ from .diag_matrix import diag, DiagMatrix
...
@@ -8,11 +8,11 @@ from .diag_matrix import diag, DiagMatrix
from
.sparse_matrix
import
SparseMatrix
,
val_like
from
.sparse_matrix
import
SparseMatrix
,
val_like
__all__
=
[
"spmm"
,
"bspmm"
,
"spspmm"
,
"m
m
"
]
__all__
=
[
"spmm"
,
"bspmm"
,
"spspmm"
,
"m
atmul
"
]
def
spmm
(
A
:
Union
[
SparseMatrix
,
DiagMatrix
],
X
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
spmm
(
A
:
Union
[
SparseMatrix
,
DiagMatrix
],
X
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Multiply a sparse matrix by a dense matrix.
"""Multiply a sparse matrix by a dense matrix
, equivalent to ``A @ X``
.
Parameters
Parameters
----------
----------
...
@@ -54,7 +54,8 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
...
@@ -54,7 +54,8 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
def
bspmm
(
A
:
Union
[
SparseMatrix
,
DiagMatrix
],
X
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
bspmm
(
A
:
Union
[
SparseMatrix
,
DiagMatrix
],
X
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Multiply a sparse matrix by a dense matrix by batches.
"""Multiply a sparse matrix by a dense matrix by batches, equivalent to
``A @ X``.
Parameters
Parameters
----------
----------
...
@@ -91,14 +92,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
...
@@ -91,14 +92,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
return
spmm
(
A
,
X
)
return
spmm
(
A
,
X
)
def
_diag_diag_mm
(
A
1
:
DiagMatrix
,
A2
:
DiagMatrix
)
->
DiagMatrix
:
def
_diag_diag_mm
(
A
:
DiagMatrix
,
B
:
DiagMatrix
)
->
DiagMatrix
:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix
"""Internal function for multiplying a diagonal matrix by a diagonal matrix
Parameters
Parameters
----------
----------
A
1
: DiagMatrix
A : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
Matrix of shape (N, M), with values of shape (nnz1)
A2
: DiagMatrix
B
: DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Matrix of shape (M, P), with values of shape (nnz2)
Returns
Returns
...
@@ -106,15 +107,15 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
...
@@ -106,15 +107,15 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
DiagMatrix
DiagMatrix
The result of multiplication.
The result of multiplication.
"""
"""
M
,
N
=
A
1
.
shape
M
,
N
=
A
.
shape
N
,
P
=
A2
.
shape
N
,
P
=
B
.
shape
common_diag_len
=
min
(
M
,
N
,
P
)
common_diag_len
=
min
(
M
,
N
,
P
)
new_diag_len
=
min
(
M
,
P
)
new_diag_len
=
min
(
M
,
P
)
diag_val
=
torch
.
zeros
(
new_diag_len
)
diag_val
=
torch
.
zeros
(
new_diag_len
)
diag_val
[:
common_diag_len
]
=
(
diag_val
[:
common_diag_len
]
=
(
A
1
.
val
[:
common_diag_len
]
*
A2
.
val
[:
common_diag_len
]
A
.
val
[:
common_diag_len
]
*
B
.
val
[:
common_diag_len
]
)
)
return
diag
(
diag_val
.
to
(
A
1
.
device
),
(
M
,
P
))
return
diag
(
diag_val
.
to
(
A
.
device
),
(
M
,
P
))
def
_sparse_diag_mm
(
A
,
D
):
def
_sparse_diag_mm
(
A
,
D
):
...
@@ -174,16 +175,17 @@ def _diag_sparse_mm(D, A):
...
@@ -174,16 +175,17 @@ def _diag_sparse_mm(D, A):
def
spspmm
(
def
spspmm
(
A
1
:
Union
[
SparseMatrix
,
DiagMatrix
],
A2
:
Union
[
SparseMatrix
,
DiagMatrix
]
A
:
Union
[
SparseMatrix
,
DiagMatrix
],
B
:
Union
[
SparseMatrix
,
DiagMatrix
]
)
->
Union
[
SparseMatrix
,
DiagMatrix
]:
)
->
Union
[
SparseMatrix
,
DiagMatrix
]:
"""Multiply a sparse matrix by a sparse matrix. The non-zero values of the
"""Multiply a sparse matrix by a sparse matrix, equivalent to ``A @ B``.
two sparse matrices must be 1D.
The non-zero values of the two sparse matrices must be 1D.
Parameters
Parameters
----------
----------
A
1
: SparseMatrix or DiagMatrix
A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
Sparse matrix of shape (N, M) with values of shape (nnz)
A2
: SparseMatrix or DiagMatrix
B
: SparseMatrix or DiagMatrix
Sparse matrix of shape (M, P) with values of shape (nnz)
Sparse matrix of shape (M, P) with values of shape (nnz)
Returns
Returns
...
@@ -198,13 +200,13 @@ def spspmm(
...
@@ -198,13 +200,13 @@ def spspmm(
>>> row1 = torch.tensor([0, 1, 1])
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1))
>>> val1 = torch.ones(len(row1))
>>> A
1
= from_coo(row1, col1, val1)
>>> A = from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2))
>>> val2 = torch.ones(len(row2))
>>>
A2
= from_coo(row2, col2, val2)
>>>
B
= from_coo(row2, col2, val2)
>>> result = dgl.sparse.spspmm(A
1
,
A2
)
>>> result = dgl.sparse.spspmm(A,
B
)
>>> print(result)
>>> print(result)
SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
[1, 2, 0, 1, 2]]),
[1, 2, 0, 1, 2]]),
...
@@ -212,73 +214,134 @@ def spspmm(
...
@@ -212,73 +214,134 @@ def spspmm(
shape=(2, 3), nnz=5)
shape=(2, 3), nnz=5)
"""
"""
assert
isinstance
(
assert
isinstance
(
A
1
,
(
SparseMatrix
,
DiagMatrix
)
A
,
(
SparseMatrix
,
DiagMatrix
)
),
f
"Expect A1 to be a SparseMatrix or DiagMatrix object, got
{
type
(
A
1
)
}
"
),
f
"Expect A1 to be a SparseMatrix or DiagMatrix object, got
{
type
(
A
)
}
"
assert
isinstance
(
assert
isinstance
(
A2
,
(
SparseMatrix
,
DiagMatrix
)
B
,
(
SparseMatrix
,
DiagMatrix
)
),
f
"Expect A2 to be a SparseMatrix or DiagMatrix object, got
{
type
(
A2
)
}
"
),
f
"Expect A2 to be a SparseMatrix or DiagMatrix object, got
{
type
(
B
)
}
"
if
isinstance
(
A
1
,
DiagMatrix
)
and
isinstance
(
A2
,
DiagMatrix
):
if
isinstance
(
A
,
DiagMatrix
)
and
isinstance
(
B
,
DiagMatrix
):
return
_diag_diag_mm
(
A
1
,
A2
)
return
_diag_diag_mm
(
A
,
B
)
if
isinstance
(
A
1
,
DiagMatrix
):
if
isinstance
(
A
,
DiagMatrix
):
return
_diag_sparse_mm
(
A
1
,
A2
)
return
_diag_sparse_mm
(
A
,
B
)
if
isinstance
(
A2
,
DiagMatrix
):
if
isinstance
(
B
,
DiagMatrix
):
return
_sparse_diag_mm
(
A
1
,
A2
)
return
_sparse_diag_mm
(
A
,
B
)
return
SparseMatrix
(
return
SparseMatrix
(
torch
.
ops
.
dgl_sparse
.
spspmm
(
A
1
.
c_sparse_matrix
,
A2
.
c_sparse_matrix
)
torch
.
ops
.
dgl_sparse
.
spspmm
(
A
.
c_sparse_matrix
,
B
.
c_sparse_matrix
)
)
)
def
m
m
(
def
m
atmul
(
A
1
:
Union
[
SparseMatrix
,
DiagMatrix
],
A
:
Union
[
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
],
A2
:
Union
[
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
],
B
:
Union
[
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
],
)
->
Union
[
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
]:
)
->
Union
[
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
]:
"""Multiply a sparse/diagonal matrix by a dense/sparse/diagonal matrix.
"""Multiply two dense/sparse/diagonal matrices, equivalent to ``A @ B``.
If an input is a SparseMatrix or DiagMatrix, its non-zero values should
be 1-D.
The supported combinations are shown as follows.
+--------------+--------+------------+--------------+
| A
\\
B | Tensor | DiagMatrix | SparseMatrix |
+--------------+--------+------------+--------------+
| Tensor | ✅ | 🚫 | 🚫 |
+--------------+--------+------------+--------------+
| SparseMatrix | ✅ | ✅ | ✅ |
+--------------+--------+------------+--------------+
| DiagMatrix | ✅ | ✅ | ✅ |
+--------------+--------+------------+--------------+
* If both matrices are torch.Tensor, it calls
\
:func:`torch.matmul()`. The result is a dense matrix.
* If both matrices are sparse or diagonal, it calls
\
:func:`dgl.sparse.spspmm`. The result is a sparse matrix.
* If :attr:`A` is sparse or diagonal while :attr:`B` is dense, it
\
calls :func:`dgl.sparse.spmm`. The result is a dense matrix.
* The operator supports batched sparse-dense matrix multiplication. In
\
this case, the sparse or diagonal matrix :attr:`A` should have shape
\
:math:`(L, M)`, where the non-zero values have a batch dimension
\
:math:`K`. The dense matrix :attr:`B` should have shape
\
:math:`(M, N, K)`. The output is a dense matrix of shape
\
:math:`(L, N, K)`.
* Sparse-sparse matrix multiplication does not support batched computation.
Parameters
Parameters
----------
----------
A1 : SparseMatrix or DiagMatrix
A : torch.Tensor, SparseMatrix or DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
The first matrix.
A2 : torch.Tensor, SparseMatrix, or DiagMatrix
B : torch.Tensor, SparseMatrix, or DiagMatrix
Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
The second matrix.
it should have values of shape (nnz2).
Returns
Returns
-------
-------
torch.Tensor or DiagMatrix or SparseMatrix
torch.Tensor, SparseMatrix or DiagMatrix
The result of multiplication of shape (N, P)
The result matrix
* It is a dense torch tensor if :attr:`A2` is so.
* It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
* It is a SparseMatrix object otherwise.
Examples
Examples
--------
--------
Multiply a diagonal matrix with a dense matrix.
>>> val = torch.randn(3)
>>> val = torch.randn(3)
>>> A
1
= diag(val)
>>> A = diag(val)
>>>
A2
= torch.randn(3, 2)
>>>
B
= torch.randn(3, 2)
>>> result = dgl.sparse.m
m
(A
1
,
A2
)
>>> result = dgl.sparse.m
atmul
(A,
B
)
>>> print(type(result))
>>> print(type(result))
<class 'torch.Tensor'>
<class 'torch.Tensor'>
>>> print(result.shape)
>>> print(result.shape)
torch.Size([3, 2])
torch.Size([3, 2])
Multiply a sparse matrix with a dense matrix.
>>> row = torch.tensor([0, 1, 1])
>>> col = torch.tensor([1, 0, 1])
>>> val = torch.randn(len(row))
>>> A = from_coo(row, col, val)
>>> X = torch.randn(2, 3)
>>> result = dgl.sparse.matmul(A, X)
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([2, 3])
Multiply a sparse matrix with a sparse matrix.
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1))
>>> A = from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2))
>>> B = from_coo(row2, col2, val2)
>>> result = dgl.sparse.matmul(A, B)
>>> print(type(result))
<class 'dgl.sparse.sparse_matrix.SparseMatrix'>
>>> print(result.shape)
(2, 3)
"""
"""
assert
isinstance
(
assert
isinstance
(
A
,
(
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
)),
(
A1
,
(
SparseMatrix
,
DiagMatrix
)
f
"Expect arg1 to be a torch.Tensor, SparseMatrix, or DiagMatrix object,"
),
f
"Expect arg1 to be a SparseMatrix, or DiagMatrix object, got
{
type
(
A1
)
}
."
f
"got
{
type
(
A
)
}
."
assert
isinstance
(
A2
,
(
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
)),
(
)
assert
isinstance
(
B
,
(
torch
.
Tensor
,
SparseMatrix
,
DiagMatrix
)),
(
f
"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix"
f
"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix"
f
"object, got
{
type
(
A2
)
}
."
f
"object, got
{
type
(
B
)
}
."
)
if
isinstance
(
A
,
torch
.
Tensor
)
and
isinstance
(
B
,
torch
.
Tensor
):
return
torch
.
matmul
(
A
,
B
)
assert
not
isinstance
(
A
,
torch
.
Tensor
),
(
f
"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, "
f
"got
{
type
(
B
)
}
."
)
)
if
isinstance
(
A2
,
torch
.
Tensor
):
if
isinstance
(
B
,
torch
.
Tensor
):
return
spmm
(
A
1
,
A2
)
return
spmm
(
A
,
B
)
if
isinstance
(
A
1
,
DiagMatrix
)
and
isinstance
(
A2
,
DiagMatrix
):
if
isinstance
(
A
,
DiagMatrix
)
and
isinstance
(
B
,
DiagMatrix
):
return
_diag_diag_mm
(
A
1
,
A2
)
return
_diag_diag_mm
(
A
,
B
)
return
spspmm
(
A
1
,
A2
)
return
spspmm
(
A
,
B
)
SparseMatrix
.
__matmul__
=
m
m
SparseMatrix
.
__matmul__
=
m
atmul
DiagMatrix
.
__matmul__
=
m
m
DiagMatrix
.
__matmul__
=
m
atmul
tests/pytorch/sparse/test_matmul.py
View file @
b5c5c860
...
@@ -4,7 +4,8 @@ import backend as F
...
@@ -4,7 +4,8 @@ import backend as F
import
pytest
import
pytest
import
torch
import
torch
from
dgl.sparse
import
bspmm
,
diag
,
from_coo
,
mm
,
val_like
from
dgl.sparse
import
bspmm
,
diag
,
from_coo
,
val_like
from
dgl.sparse.matmul
import
matmul
from
.utils
import
(
from
.utils
import
(
clone_detach_and_grad
,
clone_detach_and_grad
,
...
@@ -33,7 +34,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
...
@@ -33,7 +34,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
else
:
else
:
X
=
torch
.
randn
(
shape
[
1
],
requires_grad
=
True
,
device
=
dev
)
X
=
torch
.
randn
(
shape
[
1
],
requires_grad
=
True
,
device
=
dev
)
sparse_result
=
A
@
X
sparse_result
=
matmul
(
A
,
X
)
grad
=
torch
.
randn_like
(
sparse_result
)
grad
=
torch
.
randn_like
(
sparse_result
)
sparse_result
.
backward
(
grad
)
sparse_result
.
backward
(
grad
)
...
@@ -60,7 +61,7 @@ def test_bspmm(create_func, shape, nnz):
...
@@ -60,7 +61,7 @@ def test_bspmm(create_func, shape, nnz):
A
=
create_func
(
shape
,
nnz
,
dev
,
2
)
A
=
create_func
(
shape
,
nnz
,
dev
,
2
)
X
=
torch
.
randn
(
shape
[
1
],
10
,
2
,
requires_grad
=
True
,
device
=
dev
)
X
=
torch
.
randn
(
shape
[
1
],
10
,
2
,
requires_grad
=
True
,
device
=
dev
)
sparse_result
=
bspmm
(
A
,
X
)
sparse_result
=
matmul
(
A
,
X
)
grad
=
torch
.
randn_like
(
sparse_result
)
grad
=
torch
.
randn_like
(
sparse_result
)
sparse_result
.
backward
(
grad
)
sparse_result
.
backward
(
grad
)
...
@@ -92,7 +93,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
...
@@ -92,7 +93,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
shape2
=
(
shape_n_m
[
1
],
shape_k
)
shape2
=
(
shape_n_m
[
1
],
shape_k
)
A1
=
create_func1
(
shape1
,
nnz1
,
dev
)
A1
=
create_func1
(
shape1
,
nnz1
,
dev
)
A2
=
create_func2
(
shape2
,
nnz2
,
dev
)
A2
=
create_func2
(
shape2
,
nnz2
,
dev
)
A3
=
A1
@
A2
A3
=
matmul
(
A1
,
A2
)
grad
=
torch
.
randn_like
(
A3
.
val
)
grad
=
torch
.
randn_like
(
A3
.
val
)
A3
.
val
.
backward
(
grad
)
A3
.
val
.
backward
(
grad
)
...
@@ -132,14 +133,14 @@ def test_spspmm_duplicate():
...
@@ -132,14 +133,14 @@ def test_spspmm_duplicate():
A2
=
from_coo
(
row
,
col
,
val
,
shape
)
A2
=
from_coo
(
row
,
col
,
val
,
shape
)
try
:
try
:
A1
@
A2
matmul
(
A1
,
A2
)
except
:
except
:
pass
pass
else
:
else
:
assert
False
,
"Should raise error."
assert
False
,
"Should raise error."
try
:
try
:
A2
@
A1
matmul
(
A2
,
A1
)
except
:
except
:
pass
pass
else
:
else
:
...
@@ -155,8 +156,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
...
@@ -155,8 +156,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
A
=
create_func
(
sparse_shape
,
nnz
,
dev
)
A
=
create_func
(
sparse_shape
,
nnz
,
dev
)
diag_val
=
torch
.
randn
(
sparse_shape
[
1
],
device
=
dev
,
requires_grad
=
True
)
diag_val
=
torch
.
randn
(
sparse_shape
[
1
],
device
=
dev
,
requires_grad
=
True
)
D
=
diag
(
diag_val
,
diag_shape
)
D
=
diag
(
diag_val
,
diag_shape
)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B
=
matmul
(
A
,
D
)
B
=
mm
(
A
,
D
)
grad
=
torch
.
randn_like
(
B
.
val
)
grad
=
torch
.
randn_like
(
B
.
val
)
B
.
val
.
backward
(
grad
)
B
.
val
.
backward
(
grad
)
...
@@ -189,8 +189,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
...
@@ -189,8 +189,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
A
=
create_func
(
sparse_shape
,
nnz
,
dev
)
A
=
create_func
(
sparse_shape
,
nnz
,
dev
)
diag_val
=
torch
.
randn
(
sparse_shape
[
0
],
device
=
dev
,
requires_grad
=
True
)
diag_val
=
torch
.
randn
(
sparse_shape
[
0
],
device
=
dev
,
requires_grad
=
True
)
D
=
diag
(
diag_val
,
diag_shape
)
D
=
diag
(
diag_val
,
diag_shape
)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B
=
matmul
(
D
,
A
)
B
=
mm
(
D
,
A
)
grad
=
torch
.
randn_like
(
B
.
val
)
grad
=
torch
.
randn_like
(
B
.
val
)
B
.
val
.
backward
(
grad
)
B
.
val
.
backward
(
grad
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment