Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b57c56d9
Unverified
Commit
b57c56d9
authored
Apr 15, 2023
by
czkkkkkk
Committed by
GitHub
Apr 15, 2023
Browse files
[Sparse] Support broadcasting operators (#5544)
parent
7c465d20
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
191 additions
and
0 deletions
+191
-0
docs/source/api/python/dgl.sparse_v0.rst
docs/source/api/python/dgl.sparse_v0.rst
+12
-0
python/dgl/sparse/__init__.py
python/dgl/sparse/__init__.py
+1
-0
python/dgl/sparse/broadcast.py
python/dgl/sparse/broadcast.py
+133
-0
tests/python/pytorch/sparse/test_broadcast.py
tests/python/pytorch/sparse/test_broadcast.py
+45
-0
No files found.
docs/source/api/python/dgl.sparse_v0.rst
View file @
b57c56d9
...
...
@@ -148,3 +148,15 @@ Non-linear activation functions
:toctree: ../../generated/
softmax
Broadcast operators
```````````````````
.. autosummary::
:toctree: ../../generated/
sp_broadcast_v
sp_add_v
sp_sub_v
sp_mul_v
sp_div_v
\ No newline at end of file
python/dgl/sparse/__init__.py
View file @
b57c56d9
...
...
@@ -5,6 +5,7 @@ import sys
import
torch
from
.._ffi
import
libinfo
from
.broadcast
import
*
from
.elementwise_op
import
*
from
.elementwise_op_sp
import
*
from
.matmul
import
*
...
...
python/dgl/sparse/broadcast.py
0 → 100644
View file @
b57c56d9
"""DGL broadcast operator module."""
import
operator
import
torch
from
.sparse_matrix
import
SparseMatrix
,
val_like
def sp_broadcast_v(A: SparseMatrix, v: torch.Tensor, op: str) -> SparseMatrix:
    """Broadcast operator for sparse matrix and vector.

    :attr:`v` is broadcasted to the shape of :attr:`A` and then the operator
    is applied on the non-zero values of :attr:`A`.

    There are two cases regarding the shape of v:

    1. :attr:`v` is a vector of shape ``(1, A.shape[1])`` or ``(A.shape[1])``.
       In this case, :attr:`v` is broadcasted on the row dimension of
       :attr:`A`.
    2. :attr:`v` is a vector of shape ``(A.shape[0], 1)``. In this case,
       :attr:`v` is broadcasted on the column dimension of :attr:`A`.

    If ``A.val`` takes shape ``(nnz, D)``, then :attr:`v` will be broadcasted
    on the ``D`` dimension.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix
    v : torch.Tensor
        Vector
    op : str
        Operator in ["add", "sub", "mul", "truediv"]

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------

    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
    >>> val = torch.tensor([10, 20, 30])
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))
    >>> v = torch.tensor([1, 2, 3, 4])
    >>> dglsp.sp_broadcast_v(A, v, "add")
    SparseMatrix(indices=tensor([[1, 0, 2],
                                 [0, 3, 2]]),
                 values=tensor([11, 24, 33]),
                 shape=(3, 4), nnz=3)

    >>> v = torch.tensor([1, 2, 3]).view(-1, 1)
    >>> dglsp.sp_broadcast_v(A, v, "add")
    SparseMatrix(indices=tensor([[1, 0, 2],
                                 [0, 3, 2]]),
                 values=tensor([12, 21, 33]),
                 shape=(3, 4), nnz=3)

    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
    >>> val = torch.tensor([[10, 20], [30, 40], [50, 60]])
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))
    >>> v = torch.tensor([1, 2, 3]).view(-1, 1)
    >>> dglsp.sp_broadcast_v(A, v, "sub")
    SparseMatrix(indices=tensor([[1, 0, 2],
                                 [0, 3, 2]]),
                 values=tensor([[ 8, 18],
                                [29, 39],
                                [47, 57]]),
                 shape=(3, 4), nnz=3, val_size=(2,))
    """
    # Resolve the operator name to the corresponding Python operator
    # (operator.add, operator.sub, operator.mul, operator.truediv).
    op = getattr(operator, op)
    # Normalize a 1-D vector to a row vector so both accepted layouts,
    # (A.shape[1],) and (1, A.shape[1]), are handled uniformly below.
    if v.dim() == 1:
        v = v.view(1, -1)
    # Bug fix: the original two-part f-string was missing a space between
    # "and" and "v.shape", producing "... and<v.shape ...>" run together.
    shape_error_message = (
        f"Dimension mismatch for broadcasting. Got A.shape = {A.shape} and "
        f"v.shape = {v.shape}."
    )
    assert v.dim() <= 2 and (1 in v.shape), shape_error_message

    broadcast_dim = None
    # v can be broadcasted to A if exactly one dimension of v is 1 and the
    # other is the same as A.
    for d, (dim1, dim2) in enumerate(zip(A.shape, v.shape)):
        assert dim2 in (1, dim1), shape_error_message
        if dim1 != dim2:
            assert broadcast_dim is None, shape_error_message
            broadcast_dim = d
    # A and v have the same shape of (1, *) or (*, 1); pick the singleton
    # dimension of A as the broadcast dimension.
    if broadcast_dim is None:
        broadcast_dim = 0 if A.shape[0] == 1 else 1

    # Gather, for every non-zero of A, the vector entry it pairs with:
    # row-broadcast (dim 0 of v is 1) indexes by column, column-broadcast
    # indexes by row.
    if broadcast_dim == 0:
        v = v.view(-1)[A.col]
    else:
        v = v.view(-1)[A.row]
    # When A.val is (nnz, D), add a trailing singleton dimension so the
    # gathered (nnz,) vector broadcasts across D.
    if A.val.dim() > 1:
        v = v.view(-1, 1)
    ret_val = op(A.val, v)
    return val_like(A, ret_val)
def sp_add_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast addition of a vector onto a sparse matrix, i.e. ``A + v``.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "add")
def sp_sub_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast subtraction of a vector from a sparse matrix, i.e. ``A - v``.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "sub")
def sp_mul_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast multiplication of a sparse matrix by a vector, i.e. ``A * v``.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "mul")
def sp_div_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Broadcast division of a sparse matrix by a vector, i.e. ``A / v``.

    See the definition of :func:`sp_broadcast_v` for details.
    """
    return sp_broadcast_v(A, v, "truediv")
tests/python/pytorch/sparse/test_broadcast.py
0 → 100644
View file @
b57c56d9
import
operator
import
backend
as
F
import
pytest
import
torch
from
dgl.sparse
import
sp_broadcast_v
from
.utils
import
rand_coo
@pytest.mark.parametrize("shape", [(3, 4), (1, 5), (5, 1)])
@pytest.mark.parametrize("nnz", [1, 4])
@pytest.mark.parametrize("nz_dim", [None, 2])
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv"])
def test_sp_broadcast_v(shape, nnz, nz_dim, op):
    """Check sp_broadcast_v against a dense gather of the vector entries."""
    dev = F.ctx()
    A = rand_coo(shape, nnz, dev, nz_dim)
    torch_op = getattr(operator, op)

    def expected_val(vec, idx):
        # Pick the vector entry paired with each non-zero; append a trailing
        # singleton dim when A.val is 2-D so it broadcasts over the D dim.
        gathered = vec.view(-1)[idx]
        if A.val.dim() > 1:
            gathered = gathered.view(-1, 1)
        return torch_op(A.val, gathered)

    # v of shape (A.shape[1],): broadcast on the row dimension.
    v = torch.randn(A.shape[1], device=dev)
    res1 = sp_broadcast_v(A, v, op)
    assert torch.allclose(res1.val, expected_val(v, A.col))

    # v of shape (1, A.shape[1]): broadcast on the row dimension.
    v = torch.randn(1, A.shape[1], device=dev)
    res1 = sp_broadcast_v(A, v, op)
    assert torch.allclose(res1.val, expected_val(v, A.col))

    # v of shape (A.shape[0], 1): broadcast on the column dimension.
    v = torch.randn(A.shape[0], 1, device=dev)
    res1 = sp_broadcast_v(A, v, op)
    assert torch.allclose(res1.val, expected_val(v, A.row))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment