OpenDAS / dgl · Commit 7c465d20 (unverified)

[Sparse] Support spspdiv (#5541)

Authored Apr 14, 2023 by czkkkkkk; committed by GitHub on Apr 14, 2023.
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
Parent commit: bb1f8850
Showing 9 changed files with 115 additions and 31 deletions:

  dgl_sparse/include/sparse/elementwise_op.h             +18  -6
  dgl_sparse/include/sparse/sparse_format.h               +8  -0
  dgl_sparse/src/elemenwise_op.cc                        +39  -9
  dgl_sparse/src/python_binding.cc                        +1  -0
  dgl_sparse/src/sparse_format.cc                        +12  -0
  python/dgl/sparse/elementwise_op_sp.py                 +12 -12
  tests/python/pytorch/sparse/test_elementwise_op.py      +1  -1
  tests/python/pytorch/sparse/test_elementwise_op_sp.py  +22  -1
  tests/python/pytorch/sparse/test_softmax.py             +2  -2
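
In short, this commit adds sparse-sparse elementwise division (SpSpDiv) for matrices that share the same sparsity, lifting the previous diagonal-only restriction. Before the diffs, a minimal sketch of the resulting Python-level behavior; this assumes a DGL build that includes this commit, and the values are illustrative:

import torch
import dgl.sparse as dglsp

# Two matrices with identical sparsity.
indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
A = dglsp.spmatrix(indices, torch.tensor([2.0, 4.0, 6.0]), (3, 3))
B = dglsp.spmatrix(indices, torch.tensor([1.0, 2.0, 3.0]), (3, 3))

C = A / B  # dispatches to the new spsp_div binding
print(C.val)  # tensor([2., 2., 2.])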
dgl_sparse/include/sparse/elementwise_op.h

@@ -14,20 +14,20 @@ namespace sparse {
 /**
  * @brief Adds two sparse matrices possibly with different sparsities.
  *
- * @param A SparseMatrix
- * @param B SparseMatrix
+ * @param lhs_mat SparseMatrix
+ * @param rhs_mat SparseMatrix
  *
  * @return SparseMatrix
  */
 c10::intrusive_ptr<SparseMatrix> SpSpAdd(
-    const c10::intrusive_ptr<SparseMatrix>& A,
-    const c10::intrusive_ptr<SparseMatrix>& B);
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);

 /**
  * @brief Multiplies two sparse matrices possibly with different sparsities.
  *
- * @param A SparseMatrix
- * @param B SparseMatrix
+ * @param lhs_mat SparseMatrix
+ * @param rhs_mat SparseMatrix
  *
  * @return SparseMatrix
  */

@@ -35,6 +35,18 @@ c10::intrusive_ptr<SparseMatrix> SpSpMul(
     const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
     const c10::intrusive_ptr<SparseMatrix>& rhs_mat);

+/**
+ * @brief Divides two sparse matrices with the same sparsity.
+ *
+ * @param lhs_mat SparseMatrix
+ * @param rhs_mat SparseMatrix
+ *
+ * @return SparseMatrix
+ */
+c10::intrusive_ptr<SparseMatrix> SpSpDiv(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat);
+
 }  // namespace sparse
 }  // namespace dgl
dgl_sparse/include/sparse/sparse_format.h

@@ -14,6 +14,7 @@
 #include <torch/script.h>

 #include <memory>
+#include <utility>

 namespace dgl {
 namespace sparse {

@@ -113,6 +114,13 @@ std::shared_ptr<CSR> DiagToCSC(
 /** @brief COO transposition. */
 std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);

+/**
+ * @brief Sort the COO matrix by row and column indices.
+ * @return A pair of the sorted COO matrix and the permutation indices.
+ */
+std::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(
+    const std::shared_ptr<COO>& coo);
+
 }  // namespace sparse
 }  // namespace dgl
dgl_sparse/src/elemenwise_op.cc

@@ -19,17 +19,18 @@ namespace sparse {
 using namespace torch::autograd;

 c10::intrusive_ptr<SparseMatrix> SpSpAdd(
-    const c10::intrusive_ptr<SparseMatrix>& A,
-    const c10::intrusive_ptr<SparseMatrix>& B) {
-  ElementwiseOpSanityCheck(A, B);
-  if (A->HasDiag() && B->HasDiag()) {
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
+  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
+  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
     return SparseMatrix::FromDiagPointer(
-        A->DiagPtr(), A->value() + B->value(), A->shape());
+        lhs_mat->DiagPtr(), lhs_mat->value() + rhs_mat->value(),
+        lhs_mat->shape());
   }
-  auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
-  auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
-  auto sum = (torch_A + torch_B).coalesce();
-  return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape());
+  auto torch_lhs = COOToTorchCOO(lhs_mat->COOPtr(), lhs_mat->value());
+  auto torch_rhs = COOToTorchCOO(rhs_mat->COOPtr(), rhs_mat->value());
+  auto sum = (torch_lhs + torch_rhs).coalesce();
+  return SparseMatrix::FromCOO(sum.indices(), sum.values(), lhs_mat->shape());
 }

 class SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {

@@ -117,5 +118,34 @@ c10::intrusive_ptr<SparseMatrix> SpSpMul(
   return SparseMatrix::FromCOO(indices, val, lhs_mat->shape());
 }

+c10::intrusive_ptr<SparseMatrix> SpSpDiv(
+    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
+    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
+  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
+  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
+    return SparseMatrix::FromDiagPointer(
+        lhs_mat->DiagPtr(), lhs_mat->value() / rhs_mat->value(),
+        lhs_mat->shape());
+  }
+  std::shared_ptr<COO> sorted_lhs, sorted_rhs;
+  torch::Tensor lhs_sorted_perm, rhs_sorted_perm;
+  std::tie(sorted_lhs, lhs_sorted_perm) = COOSort(lhs_mat->COOPtr());
+  std::tie(sorted_rhs, rhs_sorted_perm) = COOSort(rhs_mat->COOPtr());
+  TORCH_CHECK(
+      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
+      "Only support SpSpDiv on sparse matrices without duplicate values");
+  TORCH_CHECK(
+      torch::equal(sorted_lhs->indices, sorted_rhs->indices),
+      "Cannot divide two COO matrices with different sparsities.");
+  // Make sure the returned matrix has the same non-zero order as lhs_mat.
+  auto lhs_sorted_rperm = lhs_sorted_perm.argsort();
+  auto rhs_perm_on_lhs = rhs_sorted_perm.index_select(0, lhs_sorted_rperm);
+  auto lhs_value = lhs_mat->value();
+  auto rhs_value = rhs_mat->value().index_select(0, rhs_perm_on_lhs);
+  auto ret_val = lhs_value / rhs_value;
+  return SparseMatrix::FromCOOPointer(
+      lhs_mat->COOPtr(), ret_val, lhs_mat->shape());
+}
+
 }  // namespace sparse
 }  // namespace dgl
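
The permutation logic above is the heart of SpSpDiv: both operands are sorted row-major via COOSort, the sorted index sets must match exactly, and the rhs values are then routed through two permutations so they line up with lhs's original non-zero order. A small PyTorch sketch of that permutation algebra, using a hypothetical standalone helper (align_rhs_values is not DGL API):

import torch

def align_rhs_values(lhs_idx, rhs_idx, rhs_val, num_cols):
    # Encode each (row, col) pair as a scalar key so a 1-D sort yields
    # row-major order -- the same trick COOSort uses.
    lhs_key = lhs_idx[0] * num_cols + lhs_idx[1]
    rhs_key = rhs_idx[0] * num_cols + rhs_idx[1]
    lhs_sorted, lhs_perm = lhs_key.sort()
    rhs_sorted, rhs_perm = rhs_key.sort()
    # Same requirement as the TORCH_CHECK: identical sparsity after sorting.
    assert torch.equal(lhs_sorted, rhs_sorted), "different sparsities"
    # lhs_perm.argsort() maps each lhs position to its sorted rank;
    # composing with rhs_perm picks the rhs entry at the same (row, col).
    return rhs_val[rhs_perm[lhs_perm.argsort()]]

# B holds the same non-zeros as A but stores them in a different order.
lhs_idx = torch.tensor([[0, 1, 1], [2, 0, 1]])
rhs_idx = torch.tensor([[1, 1, 0], [1, 0, 2]])
rhs_val = torch.tensor([10.0, 20.0, 30.0])
print(align_rhs_values(lhs_idx, rhs_idx, rhs_val, num_cols=3))
# tensor([30., 20., 10.]) -- rhs values reordered to lhs's non-zero order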
dgl_sparse/src/python_binding.cc

@@ -40,6 +40,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
       .def("from_diag", &SparseMatrix::FromDiag)
       .def("spsp_add", &SpSpAdd)
       .def("spsp_mul", &SpSpMul)
+      .def("spsp_div", &SpSpDiv)
       .def("reduce", &Reduce)
       .def("sum", &ReduceSum)
       .def("smean", &ReduceMean)
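
Operators registered in the TORCH_LIBRARY(dgl_sparse, m) block surface in Python under torch.ops. A sketch of how the new entry resolves, assuming the dgl_sparse extension library has been loaded (DGL loads it when dgl.sparse is imported):

import torch
import dgl.sparse  # loads the dgl_sparse extension library

# The name registered via .def("spsp_div", &SpSpDiv) becomes reachable as:
op = torch.ops.dgl_sparse.spsp_div
# The Python wrapper below (spsp_div in elementwise_op_sp.py) calls exactly
# this op with the underlying C++ handles (A.c_sparse_matrix, ...).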
dgl_sparse/src/sparse_format.cc

@@ -140,5 +140,17 @@ std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
   return COOFromOldDGLCOO(dgl_coo_tr);
 }

+std::pair<std::shared_ptr<COO>, torch::Tensor> COOSort(
+    const std::shared_ptr<COO>& coo) {
+  auto encoded_coo =
+      coo->indices.index({0}) * coo->num_cols + coo->indices.index({1});
+  torch::Tensor sorted, perm;
+  std::tie(sorted, perm) = encoded_coo.sort();
+  auto sorted_coo = std::make_shared<COO>(COO{
+      coo->num_rows, coo->num_cols, coo->indices.index_select(1, perm), true,
+      true});
+  return {sorted_coo, perm};
+}
+
 }  // namespace sparse
 }  // namespace dgl
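
COOSort's key trick, on a concrete example: each (row, col) pair is linearized to row * num_cols + col, so sorting the scalar keys orders the non-zeros row-major and yields the permutation as a by-product. A sketch of the same computation with plain tensors (not DGL's COO struct):

import torch

indices = torch.tensor([[2, 0, 1],
                        [1, 2, 0]])
num_cols = 4
key = indices[0] * num_cols + indices[1]  # tensor([9, 2, 4])
sorted_key, perm = key.sort()             # perm = tensor([1, 2, 0])
sorted_indices = indices.index_select(1, perm)
print(sorted_indices)  # [[0, 1, 2], [2, 0, 1]] -- row-major order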
python/dgl/sparse/elementwise_op_sp.py

@@ -3,7 +3,7 @@ from typing import Union

 import torch

-from .sparse_matrix import diag, SparseMatrix, val_like
+from .sparse_matrix import SparseMatrix, val_like
 from .utils import is_scalar, Scalar

@@ -21,6 +21,13 @@ def spsp_mul(A, B):
     )


+def spsp_div(A, B):
+    """Invoke the C++ sparse library for division."""
+    return SparseMatrix(
+        torch.ops.dgl_sparse.spsp_div(A.c_sparse_matrix, B.c_sparse_matrix)
+    )
+
+
 def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
     """Elementwise addition

@@ -141,8 +148,9 @@ def sp_mul(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
 def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
     """Elementwise division

-    If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must be
-    diagonal matrices.
+    If :attr:`B` is a sparse matrix, both :attr:`A` and :attr:`B` must have
+    the same sparsity, and the returned matrix has the same order of
+    non-zero entries as :attr:`A`.

     Parameters
     ----------

@@ -169,15 +177,7 @@ def sp_div(A: SparseMatrix, B: Union[SparseMatrix, Scalar]) -> SparseMatrix:
     """
     if is_scalar(B):
         return val_like(A, A.val / B)
-    if A.is_diag() and B.is_diag():
-        assert A.shape == B.shape, (
-            f"The shape of diagonal matrix A {A.shape} and B {B.shape} must "
-            f"match for elementwise division."
-        )
-        return diag(A.val / B.val, A.shape)
-    # Python falls back to B.__rtruediv__(A) then TypeError when
-    # NotImplemented is returned.
-    return NotImplemented
+    return spsp_div(A, B)


 def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
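
With the diagonal-only branch removed, sp_div has exactly two paths: a scalar divisor keeps A's sparsity via val_like, and a sparse divisor goes through the C++ spsp_div kernel. A short sketch of both, again assuming a build with this change:

import torch
import dgl.sparse as dglsp

indices = torch.tensor([[0, 1], [1, 0]])
A = dglsp.spmatrix(indices, torch.tensor([2.0, 4.0]), (2, 2))

# Scalar divisor: val_like keeps A's sparsity and divides the values.
print((A / 2.0).val)  # tensor([1., 2.])

# Sparse divisor: any pair with identical sparsity now works, not just
# diagonal matrices.
print((A / A).val)  # tensor([1., 1.])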
tests/python/pytorch/sparse/test_elementwise_op.py

@@ -225,7 +225,7 @@ def test_sub_sparse_diag(val_shape):
     assert torch.allclose(dense_diff, -diff4)


-@pytest.mark.parametrize("op", ["truediv", "pow"])
+@pytest.mark.parametrize("op", ["pow"])
 def test_error_op_sparse_diag(op):
     ctx = F.ctx()
     row = torch.tensor([1, 0, 2]).to(ctx)
tests/python/pytorch/sparse/test_elementwise_op_sp.py

@@ -4,7 +4,7 @@ import backend as F
 import pytest
 import torch

-from dgl.sparse import from_coo, mul, power, val_like
+from dgl.sparse import div, from_coo, mul, power, spmatrix, val_like

 from .utils import (
     rand_coo,

@@ -134,3 +134,24 @@ def test_spspmul(create_func1, create_func2, shape, nnz1, nnz2, nz_dim):
     assert torch.allclose(
         val_like(B, B.val.grad).to_dense(), DB.grad, atol=1e-05
     )
+
+
+@pytest.mark.parametrize(
+    "create_func", [rand_coo, rand_csr, rand_csc, rand_diag]
+)
+@pytest.mark.parametrize("shape", [(5, 5), (5, 3)])
+@pytest.mark.parametrize("nnz", [1, 14])
+@pytest.mark.parametrize("nz_dim", [None, 3])
+def test_spspdiv(create_func, nnz, shape, nz_dim):
+    dev = F.ctx()
+    A = create_func(shape, nnz, dev, nz_dim)
+    perm = torch.randperm(A.nnz, device=dev)
+    rperm = torch.argsort(perm)
+    B = spmatrix(A.indices()[:, perm], A.val[perm], A.shape)
+    C = div(A, B)
+    assert not C.has_duplicate()
+    assert torch.allclose(C.val, A.val / B.val[rperm], atol=1e-05)
+    assert torch.allclose(C.indices(), A.indices(), atol=1e-05)
+    # No need to test backward here, since it is handled by PyTorch.
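
The test deliberately builds B as a randomly permuted copy of A's non-zeros, so the two matrices share sparsity but not storage order; rperm inverts perm, which is why the expected values are A.val / B.val[rperm]. The identity it relies on, in isolation:

import torch

val = torch.tensor([10.0, 20.0, 30.0, 40.0])
perm = torch.randperm(4)
rperm = torch.argsort(perm)  # inverse permutation of perm
# Shuffling with perm, then gathering with rperm, restores the original order.
assert torch.equal(val[perm][rperm], val)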
tests/python/pytorch/sparse/test_softmax.py

@@ -35,9 +35,9 @@ def test_softmax(val_D, csr, dim):
     g = dgl.graph((row, col), num_nodes=max(A.shape))
     val_g = val.clone().requires_grad_()
     score = dgl.nn.functional.edge_softmax(g, val_g)
-    assert torch.allclose(A_max.val, score)
+    assert torch.allclose(A_max.val, score, atol=1e-05)
     grad = torch.randn_like(score).to(dev)
     A_max.val.backward(grad)
     score.backward(grad)
-    assert torch.allclose(A.val.grad, val_g.grad)
+    assert torch.allclose(A.val.grad, val_g.grad, atol=1e-05)