Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0afc3cf8
Unverified
Commit
0afc3cf8
authored
Aug 17, 2020
by
Jinjing Zhou
Committed by
GitHub
Aug 17, 2020
Browse files
[NN] Add ChebConv for Tensorflow (#2038)
* add chebconv * fix * lint * fix lint * docs
parent
b3538802
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
338 additions
and
1 deletion
+338
-1
docs/source/api/python/nn.tensorflow.rst
docs/source/api/python/nn.tensorflow.rst
+7
-0
python/dgl/nn/tensorflow/conv/__init__.py
python/dgl/nn/tensorflow/conv/__init__.py
+2
-0
python/dgl/nn/tensorflow/conv/chebconv.py
python/dgl/nn/tensorflow/conv/chebconv.py
+167
-0
python/dgl/nn/tensorflow/conv/densechebconv.py
python/dgl/nn/tensorflow/conv/densechebconv.py
+137
-0
tests/tensorflow/test_nn.py
tests/tensorflow/test_nn.py
+25
-1
No files found.
docs/source/api/python/nn.tensorflow.rst
View file @
0afc3cf8
...
...
@@ -45,6 +45,13 @@ SAGEConv
:members: forward
:show-inheritance:
ChebConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: dgl.nn.tensorflow.conv.ChebConv
:members: forward
:show-inheritance:
SGConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
...
python/dgl/nn/tensorflow/conv/__init__.py
View file @
0afc3cf8
...
...
@@ -6,3 +6,5 @@ from .ginconv import GINConv
from
.sageconv
import
SAGEConv
from
.sgconv
import
SGConv
from
.appnpconv
import
APPNPConv
from
.chebconv
import
ChebConv
from
.densechebconv
import
DenseChebConv
python/dgl/nn/tensorflow/conv/chebconv.py
0 → 100644
View file @
0afc3cf8
"""Tensorflow Module for Chebyshev Spectral Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
tensorflow
as
tf
from
tensorflow.keras
import
layers
import
numpy
as
np
from
....base
import
dgl_warning
from
....
import
laplacian_lambda_max
,
broadcast_nodes
,
function
as
fn
class
ChebConv
(
layers
.
Layer
):
r
"""
Description
-----------
Chebyshev Spectral Graph Convolution layer from paper `Convolutional
Neural Networks on Graphs with Fast Localized Spectral Filtering
<https://arxiv.org/pdf/1606.09375.pdf>`__.
.. math::
h_i^{l+1} &= \sum_{k=0}^{K-1} W^{k, l}z_i^{k, l}
Z^{0, l} &= H^{l}
Z^{1, l} &= \tilde{L} \cdot H^{l}
Z^{k, l} &= 2 \cdot \tilde{L} \cdot Z^{k-1, l} - Z^{k-2, l}
\tilde{L} &= 2\left(I - \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\right)/\lambda_{max} - I
where :math:`\tilde{A}` is :math:`A` + :math:`I`, :math:`W` is learnable weight.
Parameters
----------
in_feats: int
Dimension of input features; i.e, the number of dimensions of :math:`h_i^{(l)}`.
out_feats: int
Dimension of output features :math:`h_i^{(l+1)}`.
k : int
Chebyshev filter size :math:`K`.
activation : function, optional
Activation function. Default ``ReLu``.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
Note
----
ChebConv only support DGLGraph as input for now. Heterograph will report error. To be fixed.
Example
-------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import ChebConv
>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = tf.ones(6, 10)
>>> conv = ChebConv(10, 2, 2)
>>> res = conv(g, feat)
>>> res
tensor([[ 0.6163, -0.1809],
[ 0.6163, -0.1809],
[ 0.6163, -0.1809],
[ 0.9698, -1.5053],
[ 0.3664, 0.7556],
[-0.2370, 3.0164]])
"""
def
__init__
(
self
,
in_feats
,
out_feats
,
k
,
activation
=
tf
.
nn
.
relu
,
bias
=
True
):
super
(
ChebConv
,
self
).
__init__
()
self
.
_k
=
k
self
.
_in_feats
=
in_feats
self
.
_out_feats
=
out_feats
self
.
activation
=
activation
self
.
linear
=
layers
.
Dense
(
out_feats
,
use_bias
=
bias
)
def
call
(
self
,
graph
,
feat
,
lambda_max
=
None
):
r
"""
Description
-----------
Compute ChebNet layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : tf.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
lambda_max : list or tensor or None, optional.
A list(tensor) with length :math:`B`, stores the largest eigenvalue
of the normalized laplacian of each individual graph in ``graph``,
where :math:`B` is the batch size of the input graph. Default: None.
If None, this method would compute the list by calling
``dgl.laplacian_lambda_max``.
Returns
-------
tf.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
def
unnLaplacian
(
feat
,
D_invsqrt
,
graph
):
""" Operation Feat * D^-1/2 A D^-1/2 """
graph
.
ndata
[
'h'
]
=
feat
*
D_invsqrt
graph
.
update_all
(
fn
.
copy_u
(
'h'
,
'm'
),
fn
.
sum
(
'm'
,
'h'
))
return
graph
.
ndata
.
pop
(
'h'
)
*
D_invsqrt
with
graph
.
local_scope
():
in_degrees
=
tf
.
clip_by_value
(
tf
.
cast
(
graph
.
in_degrees
(),
tf
.
float32
),
clip_value_min
=
1
,
clip_value_max
=
np
.
inf
)
D_invsqrt
=
tf
.
expand_dims
(
tf
.
pow
(
in_degrees
,
-
0.5
),
axis
=-
1
)
if
lambda_max
is
None
:
try
:
lambda_max
=
laplacian_lambda_max
(
graph
)
except
BaseException
:
# if the largest eigenvalue is not found
dgl_warning
(
"Largest eigonvalue not found, using default value 2 for lambda_max"
,
RuntimeWarning
)
lambda_max
=
tf
.
constant
(
2
,
dtype
=
tf
.
float32
)
if
isinstance
(
lambda_max
,
list
):
lambda_max
=
tf
.
constant
(
lambda_max
,
dtype
=
tf
.
float32
)
if
lambda_max
.
ndim
==
1
:
lambda_max
=
tf
.
expand_dims
(
lambda_max
,
axis
=-
1
)
# (B,) to (B, 1)
# broadcast from (B, 1) to (N, 1)
lambda_max
=
broadcast_nodes
(
graph
,
lambda_max
)
re_norm
=
2.
/
lambda_max
# X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
Xt
=
X_0
=
feat
# X_1(f)
if
self
.
_k
>
1
:
h
=
unnLaplacian
(
X_0
,
D_invsqrt
,
graph
)
X_1
=
-
re_norm
*
h
+
X_0
*
(
re_norm
-
1
)
# Concatenate Xt and X_1
Xt
=
tf
.
concat
((
Xt
,
X_1
),
1
)
# Xi(x), i = 2...k
for
_
in
range
(
2
,
self
.
_k
):
h
=
unnLaplacian
(
X_1
,
D_invsqrt
,
graph
)
X_i
=
-
2
*
re_norm
*
h
+
X_1
*
2
*
(
re_norm
-
1
)
-
X_0
# Concatenate Xt and X_i
Xt
=
tf
.
concat
((
Xt
,
X_i
),
1
)
X_1
,
X_0
=
X_i
,
X_1
# linear projection
h
=
self
.
linear
(
Xt
)
# activation
if
self
.
activation
:
h
=
self
.
activation
(
h
)
return
h
python/dgl/nn/tensorflow/conv/densechebconv.py
0 → 100644
View file @
0afc3cf8
"""Tensorflow Module for DenseChebConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
import
tensorflow
as
tf
from
tensorflow.keras
import
layers
import
numpy
as
np
class
DenseChebConv
(
layers
.
Layer
):
r
"""
Description
-----------
Chebyshev Spectral Graph Convolution layer from paper `Convolutional
Neural Networks on Graphs with Fast Localized Spectral Filtering
<https://arxiv.org/pdf/1606.09375.pdf>`__.
We recommend to use this module when applying ChebConv on dense graphs.
Parameters
----------
in_feats: int
Dimension of input features :math:`h_i^{(l)}`.
out_feats: int
Dimension of output features :math:`h_i^{(l+1)}`.
k : int
Chebyshev filter size.
activation : function, optional
Activation function, default is ReLu.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
Example
-------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import DenseChebConv
>>>
>>> feat = tf.ones(6, 10)
>>> adj = tf.tensor([[0., 0., 1., 0., 0., 0.],
... [1., 0., 0., 0., 0., 0.],
... [0., 1., 0., 0., 0., 0.],
... [0., 0., 1., 0., 0., 1.],
... [0., 0., 0., 1., 0., 0.],
... [0., 0., 0., 0., 0., 0.]])
>>> conv = DenseChebConv(10, 2, 2)
>>> res = conv(adj, feat)
>>> res
tensor([[-3.3516, -2.4797],
[-3.3516, -2.4797],
[-3.3516, -2.4797],
[-4.5192, -3.0835],
[-2.5259, -2.0527],
[-0.5327, -1.0219]])
See also
--------
`ChebConv <https://docs.dgl.ai/api/python/nn.tensorflow.html#chebconv>`__
"""
def
__init__
(
self
,
in_feats
,
out_feats
,
k
,
bias
=
True
):
super
(
DenseChebConv
,
self
).
__init__
()
self
.
_in_feats
=
in_feats
self
.
_out_feats
=
out_feats
self
.
_k
=
k
# keras initializer assume last two dims as fan_in and fan_out
xinit
=
tf
.
keras
.
initializers
.
glorot_normal
()
self
.
W
=
tf
.
Variable
(
initial_value
=
xinit
(
shape
=
(
k
,
in_feats
,
out_feats
),
dtype
=
'float32'
),
trainable
=
True
)
if
bias
:
zeroinit
=
tf
.
keras
.
initializers
.
zeros
()
self
.
bias
=
tf
.
Variable
(
initial_value
=
zeroinit
(
shape
=
(
out_feats
),
dtype
=
'float32'
),
trainable
=
True
)
else
:
self
.
bias
=
None
def
call
(
self
,
adj
,
feat
,
lambda_max
=
None
):
r
"""
Description
-----------
Compute (Dense) Chebyshev Spectral Graph Convolution layer.
Parameters
----------
adj : tf.Tensor
The adjacency matrix of the graph to apply Graph Convolution on,
should be of shape :math:`(N, N)`, where a row represents the destination
and a column represents the source.
feat : tf.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
lambda_max : float or None, optional
A float value indicates the largest eigenvalue of given graph.
Default: None.
Returns
-------
tf.Tensor
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
A
=
adj
num_nodes
=
A
.
shape
[
0
]
in_degree
=
1
/
tf
.
sqrt
(
tf
.
clip_by_value
(
tf
.
reduce_sum
(
A
,
1
),
clip_value_min
=
1
,
clip_value_max
=
np
.
inf
))
D_invsqrt
=
tf
.
linalg
.
diag
(
in_degree
)
I
=
tf
.
eye
(
num_nodes
)
L
=
I
-
D_invsqrt
@
A
@
D_invsqrt
if
lambda_max
is
None
:
lambda_
=
tf
.
linalg
.
eig
(
L
)[
0
][:,
0
]
lambda_max
=
tf
.
reduce_max
(
lambda_
)
L_hat
=
2
*
L
/
lambda_max
-
I
Z
=
[
tf
.
eye
(
num_nodes
)]
for
i
in
range
(
1
,
self
.
_k
):
if
i
==
1
:
Z
.
append
(
L_hat
)
else
:
Z
.
append
(
2
*
L_hat
@
Z
[
-
1
]
-
Z
[
-
2
])
Zs
=
tf
.
stack
(
Z
,
0
)
# (k, n, n)
Zh
=
(
Zs
@
tf
.
expand_dims
(
feat
,
axis
=
0
)
@
self
.
W
)
Zh
=
tf
.
reduce_sum
(
Zh
,
0
)
if
self
.
bias
is
not
None
:
Zh
=
Zh
+
self
.
bias
return
Zh
tests/tensorflow/test_nn.py
View file @
0afc3cf8
...
...
@@ -481,6 +481,30 @@ def test_hetero_conv(agg, idtype):
assert
mod3
.
carg1
==
0
assert
mod3
.
carg2
==
1
def
test_dense_cheb_conv
():
for
k
in
range
(
3
,
4
):
ctx
=
F
.
ctx
()
g
=
dgl
.
DGLGraph
(
sp
.
sparse
.
random
(
100
,
100
,
density
=
0.1
,
random_state
=
42
))
g
=
g
.
to
(
ctx
)
adj
=
tf
.
sparse
.
to_dense
(
tf
.
sparse
.
reorder
(
g
.
adjacency_matrix
(
ctx
=
ctx
)))
cheb
=
nn
.
ChebConv
(
5
,
2
,
k
,
None
,
bias
=
True
)
dense_cheb
=
nn
.
DenseChebConv
(
5
,
2
,
k
,
bias
=
True
)
# init cheb modules
feat
=
F
.
ones
((
100
,
5
))
out_cheb
=
cheb
(
g
,
feat
,
[
2.0
])
dense_cheb
.
W
=
tf
.
reshape
(
cheb
.
linear
.
weights
[
0
],
(
k
,
5
,
2
))
if
cheb
.
linear
.
bias
is
not
None
:
dense_cheb
.
bias
=
cheb
.
linear
.
bias
out_dense_cheb
=
dense_cheb
(
adj
,
feat
,
2.0
)
print
(
out_cheb
-
out_dense_cheb
)
assert
F
.
allclose
(
out_cheb
,
out_dense_cheb
)
if
__name__
==
'__main__'
:
test_graph_conv
()
# test_set2set()
...
...
@@ -500,5 +524,5 @@ if __name__ == '__main__':
# test_gmm_conv()
# test_dense_graph_conv()
# test_dense_sage_conv()
#
test_dense_cheb_conv()
test_dense_cheb_conv
()
# test_sequential()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment