Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b3aec7ae
Unverified
Commit
b3aec7ae
authored
Dec 22, 2022
by
czkkkkkk
Committed by
GitHub
Dec 22, 2022
Browse files
[Sparse] Add mock_sddmm in mock_sparse (#5059)
parent
56ce60b0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
2 deletions
+55
-2
python/dgl/mock_sparse/sddmm.py
python/dgl/mock_sparse/sddmm.py
+55
-2
No files found.
python/dgl/mock_sparse/sddmm.py
View file @
b3aec7ae
"""Sampled Dense-Dense Matrix Multiplication (SDDMM) operator module."""
"""Sampled Dense-Dense Matrix Multiplication (SDDMM) operator module."""
import
torch
import
torch
from
.sp_matrix
import
SparseMatrix
from
.sp_matrix
import
create_from_coo
,
SparseMatrix
__all__
=
[
"sddmm"
]
__all__
=
[
"sddmm"
,
"mock_bsddmm"
]
def
sddmm
(
def
sddmm
(
...
@@ -56,3 +56,56 @@ def sddmm(
...
@@ -56,3 +56,56 @@ def sddmm(
# PyTorch's sddmm operator only supports CSR format.
# PyTorch's sddmm operator only supports CSR format.
res
=
torch
.
sparse
.
sampled_addmm
(
A
.
adj
.
to_sparse_csr
(),
mat1
,
mat2
)
res
=
torch
.
sparse
.
sampled_addmm
(
A
.
adj
.
to_sparse_csr
(),
mat1
,
mat2
)
return
SparseMatrix
(
A
.
row
,
A
.
col
,
res
.
values
(),
A
.
adj
.
shape
)
return
SparseMatrix
(
A
.
row
,
A
.
col
,
res
.
values
(),
A
.
adj
.
shape
)
def
mock_bsddmm
(
A
:
SparseMatrix
,
mat1
:
torch
.
Tensor
,
mat2
:
torch
.
Tensor
)
->
SparseMatrix
:
r
"""Batched Sampled-Dense-Dense Matrix Multiplication (SDDMM).
``bsddmm`` conducts `sddmm` for each batch of the two dense matrices
independently.
In particular, :attr:``mat1`` and :attr:``mat2`` can be 2-D, which will be
reshape as `(B, M, 1)` and `(B, 1, K)` in the computation.
Parameters
----------
A : SparseMatrix
Sparse matrix of shape `(M, N)`.
mat1 : Tensor
Dense matrix of shape `(B, M, K)` or `(B, M,)`
mat2 : Tensor
Dense matrix of shape `(B, K, N)` or `(B, K,)`
Returns
-------
SparseMatrix
Sparse matrix of shape `(M, N)` with non-zero values of `B` dimension.
Examples
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([2, 3, 3])
>>> val = torch.arange(1, 4).float()
>>> A = create_from_coo(row, col, val, (3, 4))
>>> mat1 = torch.randn(2, 3, 5)
>>> mat2 = torch.randn(2, 5, 4)
>>> dgl.mock_sparse.mock_bsddmm(A, mat1, mat2)
SparseMatrix(indices=tensor([[1, 1, 2],
[2, 3, 3]]),
values=tensor([[-0.6765, -0.4017],
[ 3.3290, 6.9016],
[ 4.8184, 5.8882]]),
shape=(3, 4), nnz=3)
"""
batch_mat1
=
[
mat1
[
i
,
...]
for
i
in
range
(
mat1
.
shape
[
0
])]
batch_mat2
=
[
mat2
[
i
,
...]
for
i
in
range
(
mat2
.
shape
[
0
])]
batch_ret
=
[
sddmm
(
A
,
lhs
,
rhs
)
for
lhs
,
rhs
in
zip
(
batch_mat1
,
batch_mat2
)]
return
create_from_coo
(
row
=
A
.
row
,
col
=
A
.
col
,
val
=
torch
.
stack
([
sp_mat
.
val
for
sp_mat
in
batch_ret
],
dim
=-
1
),
shape
=
A
.
shape
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment