Unverified Commit b1e69105 authored by Tianjun Xiao's avatar Tianjun Xiao Committed by GitHub
Browse files

[Doc] API Doc update for mxnet and tf, remove some degree check (#2028)

* mx tf relconv

* use method instead of private attr

* src and dst have different fc for gat

* update edgeconv

* change sage and sgconv

* no degree check on gin

* add remainding API doc

* fix pylint

* infer fc_src and fc_dst, only one tensor for block

* fix pytest
parent 4f1da61b
......@@ -184,12 +184,30 @@ class GATConv(nn.Module):
The attention weights are using xavier initialization method.
"""
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.fc.weight, gain=gain)
if hasattr(self, 'fc'):
nn.init.xavier_normal_(self.fc.weight, gain=gain)
else:
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat):
r"""
......@@ -236,8 +254,10 @@ class GATConv(nn.Module):
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats)
if not hasattr(self, 'fc_src'):
self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(
......
......@@ -96,6 +96,20 @@ class GatedGraphConv(nn.Module):
init.xavier_normal_(linear.weight, gain=gain)
init.zeros_(linear.bias)
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, etypes):
"""
......
......@@ -4,7 +4,6 @@ import torch as th
from torch import nn
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
......@@ -32,28 +31,6 @@ class GINConv(nn.Module):
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Default: ``False``.
Notes
-----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes, the aggregation function will be appied on empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zere-in-degree nodes manually.
A common practise to handle this is to filter out the nodes with zere-in-degree when use
after conv.
Example
-------
......@@ -63,31 +40,29 @@ class GINConv(nn.Module):
>>> from dgl.nn import GINConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> lin = th.nn.Linear(10, 10)
>>> conv = GINConv(lin, 'max')
>>> res = conv(g, feat)
>>> res
tensor([[ 1.2330, -0.1572, 0.0622, -3.1567, -2.2414, -0.7275, 0.6311, 1.0396,
1.7008, -1.2468],
[ 1.2330, -0.1572, 0.0622, -3.1567, -2.2414, -0.7275, 0.6311, 1.0396,
1.7008, -1.2468],
[ 1.2330, -0.1572, 0.0622, -3.1567, -2.2414, -0.7275, 0.6311, 1.0396,
1.7008, -1.2468],
[ 1.2330, -0.1572, 0.0622, -3.1567, -2.2414, -0.7275, 0.6311, 1.0396,
1.7008, -1.2468],
[ 1.2330, -0.1572, 0.0622, -3.1567, -2.2414, -0.7275, 0.6311, 1.0396,
1.7008, -1.2468],
[ 1.2330, -0.1572, 0.0622, -3.1567, -2.2414, -0.7275, 0.6311, 1.0396,
1.7008, -1.2468]], grad_fn=<AddmmBackward>)
tensor([[-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855,
0.8843, -0.8764],
[-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855,
0.8843, -0.8764],
[-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855,
0.8843, -0.8764],
[-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855,
0.8843, -0.8764],
[-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855,
0.8843, -0.8764],
[-0.1804, 0.0758, -0.5159, 0.3569, -0.1408, -0.1395, -0.2387, 0.7773,
0.5266, -0.4465]], grad_fn=<AddmmBackward>)
"""
def __init__(self,
apply_func,
aggregator_type,
init_eps=0,
learn_eps=False,
allow_zero_in_degree=False):
learn_eps=False):
super(GINConv, self).__init__()
self.apply_func = apply_func
self._aggregator_type = aggregator_type
......@@ -104,7 +79,6 @@ class GINConv(nn.Module):
self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
self._allow_zero_in_degree = allow_zero_in_degree
def forward(self, graph, feat):
r"""
......@@ -132,28 +106,8 @@ class GINConv(nn.Module):
:math:`D_{out}` is the output dimensionality of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
Raises
------
DGLError
If there are 0-in-degree nodes in the input graph, it will raise DGLError
since no message will be passed to those nodes. This will cause invalid output.
The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any() and \
(self._aggregator_type not in ['sum', 'mean']):
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
......
......@@ -171,6 +171,20 @@ class GMMConv(nn.Module):
if self.bias is not None:
init.zeros_(self.bias.data)
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, pseudo):
"""
......
......@@ -171,6 +171,20 @@ class GraphConv(nn.Module):
if self.bias is not None:
init.zeros_(self.bias)
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, weight=None):
r"""
......
......@@ -4,7 +4,6 @@ import torch as th
from torch import nn
from torch.nn import init
from ....base import DGLError
from .... import function as fn
from ..utils import Identity
from ....utils import expand_as_pair
......@@ -48,28 +47,6 @@ class NNConv(nn.Module):
If True, use residual connection. Default: ``False``.
bias : bool, optional
If True, adds a learnable bias to the output. Default: ``True``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves.
Notes
-----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes, the aggregation function will be appied on empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zere-in-degree nodes manually.
A common practise to handle this is to filter out the nodes with zere-in-degree when use
after conv.
Examples
--------
......@@ -117,8 +94,7 @@ class NNConv(nn.Module):
edge_func,
aggregator_type='mean',
residual=False,
bias=True,
allow_zero_in_degree=False):
bias=True):
super(NNConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
......@@ -143,7 +119,6 @@ class NNConv(nn.Module):
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self._allow_zero_in_degree = allow_zero_in_degree
self.reset_parameters()
def reset_parameters(self):
......@@ -186,18 +161,6 @@ class NNConv(nn.Module):
is the output feature size.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
feat_src, feat_dst = expand_as_pair(feat, graph)
# (n, d_in, 1)
......
......@@ -116,6 +116,20 @@ class SGConv(nn.Module):
if self.fc.bias is not None:
nn.init.zeros_(self.fc.bias)
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat):
r"""
......
......@@ -107,8 +107,9 @@ class HeteroGraphConv(nn.Module):
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
if callable(set_allow_zero_in_degree_fn):
set_allow_zero_in_degree_fn(True)
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
......
......@@ -5,6 +5,7 @@ from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ....base import DGLError
from ....ops import edge_softmax
from ..utils import Identity
......@@ -12,7 +13,11 @@ from ..utils import Identity
class GATConv(layers.Layer):
r"""Apply `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
r"""
Description
-----------
Apply `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
over an input signal.
.. math::
......@@ -22,36 +27,117 @@ class GATConv(layers.Layer):
node :math:`j`:
.. math::
\alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l})
\alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})
e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)
e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)
Parameters
----------
in_feats : int, or a pair of ints
Input feature size.
in_feats : int, or pair of ints
Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
ATConv can be applied on homogeneous graph and unidirectional
`bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
out_feats : int
Output feature size.
Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
num_heads : int
Number of heads in Multi-Head Attention.
feat_drop : float, optional
Dropout rate on feature, defaults: ``0``.
Dropout rate on feature. Defaults: ``0``.
attn_drop : float, optional
Dropout rate on attention weight, defaults: ``0``.
Dropout rate on attention weight. Defaults: ``0``.
negative_slope : float, optional
LeakyReLU angle of negative slope.
LeakyReLU angle of negative slope. Defaults: ``0.2``.
residual : bool, optional
If True, use residual connection.
If True, use residual connection. Defaults: ``False``.
activation : callable activation function/layer or None, optional.
If not None, applies an activation function to the updated node features.
Default: ``None``.
"""
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
Notes
-----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes, the aggregation function will be appied on empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zere-in-degree nodes manually.
A common practise to handle this is to filter out the nodes with zere-in-degree when use
after conv.
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import GATConv
>>>
>>> # Case 1: Homogeneous graph
>>> with tf.device("CPU:0"):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = tf.ones((6, 10))
>>> gatconv = GATConv(10, 2, num_heads=3)
>>> res = gatconv(g, feat)
>>> res
<tf.Tensor: shape=(6, 3, 2), dtype=float32, numpy=
array([[[ 0.75311995, -1.8093625 ],
[-0.12128812, -0.78072834],
[-0.49870574, -0.15074375]],
[[ 0.75311995, -1.8093625 ],
[-0.12128812, -0.78072834],
[-0.49870574, -0.15074375]],
[[ 0.75311995, -1.8093625 ],
[-0.12128812, -0.78072834],
[-0.49870574, -0.15074375]],
[[ 0.75311995, -1.8093626 ],
[-0.12128813, -0.78072834],
[-0.49870574, -0.15074375]],
[[ 0.75311995, -1.8093625 ],
[-0.12128812, -0.78072834],
[-0.49870574, -0.15074375]],
[[ 0.75311995, -1.8093625 ],
[-0.12128812, -0.78072834],
[-0.49870574, -0.15074375]]], dtype=float32)>
>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.bipartite((u, v))
>>> with tf.device("CPU:0"):
>>> u_feat = tf.convert_to_tensor(np.random.rand(2, 5))
>>> v_feat = tf.convert_to_tensor(np.random.rand(4, 10))
>>> gatconv = GATConv((5,10), 2, 3)
>>> res = gatconv(g, (u_feat, v_feat))
>>> res
<tf.Tensor: shape=(4, 3, 2), dtype=float32, numpy=
array([[[-0.89649093, -0.74841046],
[ 0.5088224 , 0.10908248],
[ 0.55670375, -0.6811229 ]],
[[-0.7905004 , -0.1457274 ],
[ 0.2248168 , 0.93014705],
[ 0.12816726, -0.4093595 ]],
[[-0.85875374, -0.53382933],
[ 0.36841977, 0.51498866],
[ 0.31893706, -0.5303393 ]],
[[-0.89649093, -0.74841046],
[ 0.5088224 , 0.10908248],
[ 0.55670375, -0.6811229 ]]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
......@@ -60,15 +146,23 @@ class GATConv(layers.Layer):
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None):
activation=None,
allow_zero_in_degree=False):
super(GATConv, self).__init__()
self._num_heads = num_heads
self._in_feats = in_feats
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
xinit = tf.keras.initializers.VarianceScaling(scale=np.sqrt(
2), mode="fan_avg", distribution="untruncated_normal")
self.fc = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
if isinstance(in_feats, tuple):
self.fc_src = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
self.fc_dst = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
else:
self.fc = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
self.attn_l = tf.Variable(initial_value=xinit(
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True)
self.attn_r = tf.Variable(initial_value=xinit(
......@@ -87,8 +181,26 @@ class GATConv(layers.Layer):
# self.register_buffer('res_fc', None)
self.activation = activation
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def call(self, graph, feat):
r"""Compute graph attention network layer.
r"""
Description
-----------
Compute graph attention network layer.
Parameters
----------
......@@ -105,13 +217,34 @@ class GATConv(layers.Layer):
tf.Tensor
The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
is the number of heads, and :math:`D_{out}` is size of output feature.
Raises
------
DGLError
If there are 0-in-degree nodes in the input graph, it will raise DGLError
since no message will be passed to those nodes. This will cause invalid output.
The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = tf.reshape(self.fc(h_src), (-1, self._num_heads, self._out_feats))
feat_dst = tf.reshape(self.fc(h_dst), (-1, self._num_heads, self._out_feats))
if not hasattr(self, 'fc_src'):
self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = tf.reshape(self.fc_src(h_src), (-1, self._num_heads, self._out_feats))
feat_dst = tf.reshape(self.fc_dst(h_dst), (-1, self._num_heads, self._out_feats))
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = tf.reshape(
......
......@@ -8,7 +8,11 @@ from ....utils import expand_as_pair
class GINConv(layers.Layer):
r"""Graph Isomorphism Network layer from paper `How Powerful are Graph
r"""
Description
-----------
Graph Isomorphism Network layer from paper `How Powerful are Graph
Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__.
.. math::
......@@ -26,7 +30,36 @@ class GINConv(layers.Layer):
init_eps : float, optional
Initial :math:`\epsilon` value, default: ``0``.
learn_eps : bool, optional
If True, :math:`\epsilon` will be a learnable parameter.
If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
Example
-------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import GINConv
>>>
>>> with tf.device("CPU:0"):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = tf.ones((6, 10))
>>> lin = tf.keras.layers.Dense(10)
>>> conv = GINConv(lin, 'max')
>>> res = conv(g, feat)
>>> res
<tf.Tensor: shape=(6, 10), dtype=float32, numpy=
array([[-0.1090256 , 1.9050574 , -0.30704725, -1.995831 , -0.36399186,
1.10414 , 2.4885745 , -0.35387516, 1.3568261 , 1.7267858 ],
[-0.1090256 , 1.9050574 , -0.30704725, -1.995831 , -0.36399186,
1.10414 , 2.4885745 , -0.35387516, 1.3568261 , 1.7267858 ],
[-0.1090256 , 1.9050574 , -0.30704725, -1.995831 , -0.36399186,
1.10414 , 2.4885745 , -0.35387516, 1.3568261 , 1.7267858 ],
[-0.1090256 , 1.9050574 , -0.30704725, -1.995831 , -0.36399186,
1.10414 , 2.4885745 , -0.35387516, 1.3568261 , 1.7267858 ],
[-0.1090256 , 1.9050574 , -0.30704725, -1.995831 , -0.36399186,
1.10414 , 2.4885745 , -0.35387516, 1.3568261 , 1.7267858 ],
[-0.0545128 , 0.9525287 , -0.15352362, -0.9979155 , -0.18199593,
0.55207 , 1.2442873 , -0.17693758, 0.67841303, 0.8633929 ]],
dtype=float32)>
"""
def __init__(self,
apply_func,
......@@ -47,13 +80,16 @@ class GINConv(layers.Layer):
self.eps = tf.Variable(initial_value=[init_eps], dtype=tf.float32, trainable=learn_eps)
def call(self, graph, feat):
r"""Compute Graph Isomorphism Network layer.
r"""
Description
-----------
Compute Graph Isomorphism Network layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : tf.Tensor or pair of tf.Tensor
If a tf.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
......
......@@ -117,8 +117,8 @@ class GraphConv(layers.Layer):
>>> v = [0, 1, 2, 3, 2]
>>> with tf.device("CPU:0"):
... g = dgl.bipartite((u, v))
... u_fea = th.rand(2, 5)
... v_fea = th.rand(4, 5)
... u_fea = tf.convert_to_tensor(np.random.rand(2, 5))
... v_fea = tf.convert_to_tensor(np.random.rand(4, 5))
... conv = GraphConv(5, 2, norm='both', weight=True, bias=True)
... res = conv(g, (u_fea, v_fea))
>>> res
......@@ -161,6 +161,20 @@ class GraphConv(layers.Layer):
self._activation = activation
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def call(self, graph, feat, weight=None):
r"""
......
......@@ -8,7 +8,11 @@ from .. import utils
class RelGraphConv(layers.Layer):
r"""Relational graph convolution layer.
r"""
Description
-----------
Relational graph convolution layer.
Relational graph convolution is introduced in "`Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
......@@ -30,39 +34,85 @@ class RelGraphConv(layers.Layer):
W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}
where :math:`B` is the number of bases.
where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
with coefficients :math:`a_{rb}^{(l)}`.
The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`
number of block diagonal matrices. We refer :math:`B` as the number of bases.
The block regularization decomposes :math:`W_r` by:
.. math::
W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}
where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block
bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`.
Parameters
----------
in_feat : int
Input feature size.
Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.
out_feat : int
Output feature size.
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_rels : int
Number of relations.
Number of relations. .
regularizer : str
Which weight regularizer to use "basis" or "bdd"
Which weight regularizer to use "basis" or "bdd".
"basis" is short for basis-diagonal-decomposition.
"bdd" is short for block-diagonal-decomposition.
num_bases : int, optional
Number of bases. If is none, use number of relations. Default: None.
Number of bases. If is none, use number of relations. Default: ``None``.
bias : bool, optional
True if bias is added. Default: True
True if bias is added. Default: ``True``.
activation : callable, optional
Activation function. Default: None
Activation function. Default: ``None``.
self_loop : bool, optional
True to include self loop message. Default: False
True to include self loop message. Default: ``True``.
low_mem : bool, optional
True to use low memory implementation of relation message passing function. Default: False
This option trade speed with memory consumption, and will slowdown the forward/backward.
Turn it on when you encounter OOM problem during training or evaluation.
True to use low memory implementation of relation message passing function. Default: False.
This option trades speed with memory consumption, and will slowdown the forward/backward.
Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``.
dropout : float, optional
Dropout rate. Default: 0.0
Dropout rate. Default: ``0.0``
layer_norm: float, optional
Add layer norm. Default: False
"""
Add layer norm. Default: ``False``
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import RelGraphConv
>>>
>>> with tf.device("CPU:0"):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = tf.ones((6, 10))
>>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
>>> etype = tf.convert_to_tensor(np.array([0,1,2,0,1,2]).astype(np.int64))
>>> res = conv(g, feat, etype)
>>> res
<tf.Tensor: shape=(6, 2), dtype=float32, numpy=
array([[-0.02938664, 1.7932655 ],
[ 0.1146394 , 0.48319 ],
[-0.02938664, 1.7932655 ],
[ 1.2054908 , -0.26098895],
[ 0.1146394 , 0.48319 ],
[ 0.75915515, 1.1454091 ]], dtype=float32)>
>>> # One-hot input
>>> with tf.device("CPU:0"):
>>> one_hot_feat = tf.convert_to_tensor(np.array([0,1,2,3,4,5]).astype(np.int64))
>>> res = conv(g, one_hot_feat, etype)
>>> res
<tf.Tensor: shape=(6, 2), dtype=float32, numpy=
array([[-0.24205256, -0.7922753 ],
[ 0.62085056, 0.4893622 ],
[-0.9484881 , -0.26546806],
[-0.2163915 , -0.12585883],
[-0.14293689, 0.77483284],
[ 0.091169 , -0.06761569]], dtype=float32)>
"""
def __init__(self,
in_feat,
out_feat,
......@@ -71,7 +121,7 @@ class RelGraphConv(layers.Layer):
num_bases=None,
bias=True,
activation=None,
self_loop=False,
self_loop=True,
low_mem=False,
dropout=0.0,
layer_norm=False):
......@@ -206,6 +256,7 @@ class RelGraphConv(layers.Layer):
The graph.
x : tf.Tensor
Input node features. Could be either
* :math:`(|V|, D)` dense tensor
* :math:`(|V|,)` int64 vector, representing the categorical values of each
node. We then treat the input feature as an one-hot encoding feature.
......
......@@ -8,24 +8,30 @@ from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(layers.Layer):
r"""GraphSAGE layer from paper `Inductive Representation Learning on
r"""
Description
-----------
GraphSAGE layer from paper `Inductive Representation Learning on
Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.
.. math::
h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate}
h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1} + b) \right)
h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)
h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{l})
h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{l})
Parameters
----------
in_feats : int, or pair of ints
Input feature size.
Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
GATConv can be applied on homogeneous graph and unidirectional
`bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
If the layer applies on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
......@@ -33,7 +39,7 @@ class SAGEConv(layers.Layer):
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
feat_drop : float
Dropout rate on features, default: ``0``.
aggregator_type : str
......@@ -45,8 +51,46 @@ class SAGEConv(layers.Layer):
activation : callable activation function/layer or None, optional
If not None, applies an activation function to the updated node features.
Default: ``None``.
"""
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import SAGEConv
>>>
>>> # Case 1: Homogeneous graph
>>> with tf.device("CPU:0"):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = tf.ones((6, 10))
>>> conv = SAGEConv(10, 2, 'pool')
>>> res = conv(g, feat)
>>> res
<tf.Tensor: shape=(6, 2), dtype=float32, numpy=
array([[-3.6633523 , -0.90711546],
[-3.6633523 , -0.90711546],
[-3.6633523 , -0.90711546],
[-3.6633523 , -0.90711546],
[-3.6633523 , -0.90711546],
[-3.6633523 , -0.90711546]], dtype=float32)>
>>> # Case 2: Unidirectional bipartite graph
>>> with tf.device("CPU:0"):
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.bipartite((u, v))
>>> u_fea = tf.convert_to_tensor(np.random.rand(2, 5))
>>> v_fea = tf.convert_to_tensor(np.random.rand(4, 5))
>>> conv = SAGEConv((5, 10), 2, 'mean')
>>> res = conv(g, (u_fea, v_fea))
>>> res
<tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[-0.59453356, -0.4055441 ],
[-0.47459763, -0.717764 ],
[ 0.3221837 , -0.29876417],
[-0.63356155, 0.09390211]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
......@@ -82,17 +126,21 @@ class SAGEConv(layers.Layer):
return {'neigh': rst}
def call(self, graph, feat):
r"""Compute GraphSAGE layer.
r"""
Description
-----------
Compute GraphSAGE layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : tf.Tensor or pair of tf.Tensor
If a single tensor is given, it represents the input feature of shape
If a tf.Tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors are given, the pair must contain two tensors of shape
If a pair of tf.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns
......
......@@ -5,53 +5,124 @@ from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ....base import DGLError
class SGConv(layers.Layer):
r"""Simplifying Graph Convolution layer from paper `Simplifying Graph
r"""
Description
-----------
Simplifying Graph Convolution layer from paper `Simplifying Graph
Convolutional Networks <https://arxiv.org/pdf/1902.07153.pdf>`__.
.. math::
H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l}
H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta
where :math:`\tilde{A}` is :math:`A` + :math:`I`.
Thus the graph input is expected to have self-loop edges added.
Parameters
----------
in_feats : int
Number of input features.
Number of input features; i.e, the number of dimensions of :math:`X`.
out_feats : int
Number of output features.
Number of output features; i.e, the number of dimensions of :math:`H^{K}`.
k : int
Number of hops :math:`K`. Defaults:``1``.
cached : bool
If True, the module would cache
.. math::
(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta
(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}})^K X\Theta
at the first forward call. This parameter should only be set to
``True`` in Transductive Learning setting.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
norm : callable activation function/layer or None, optional
If not None, applies normalization to the updated node features.
"""
If not None, applies normalization to the updated node features. Default: ``False``.
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Default: ``False``.
Notes
-----
Zero in-degree nodes will lead to invalid output value. This is because no message
will be passed to those nodes, the aggregation function will be appied on empty input.
A common practice to avoid this is to add a self-loop for each node in the graph if
it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph
>>> g = dgl.add_self_loop(g)
Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
to ``True`` for those cases to unblock the code and handle zere-in-degree nodes manually.
A common practise to handle this is to filter out the nodes with zere-in-degree when use
after conv.
Example
-------
>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import SGConv
>>>
>>> with tf.device("CPU:0"):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = tf.ones((6, 10))
>>> conv = SGConv(10, 2, k=2, cached=True)
>>> res = conv(g, feat)
>>> res
<tf.Tensor: shape=(6, 2), dtype=float32, numpy=
array([[0.61023676, 0.5246612 ],
[0.61023676, 0.5246612 ],
[0.61023676, 0.5246612 ],
[0.8697353 , 0.7477695 ],
[0.60570633, 0.520766 ],
[0.6102368 , 0.52466124]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None):
norm=None,
allow_zero_in_degree=False):
super(SGConv, self).__init__()
self.fc = layers.Dense(out_feats, use_bias=bias)
self._cached = cached
self._cached_h = None
self._k = k
self.norm = norm
self._allow_zero_in_degree = allow_zero_in_degree
def set_allow_zero_in_degree(self, set_value):
r"""
Description
-----------
Set allow_zero_in_degree flag.
Parameters
----------
set_value : bool
The value to be set to the flag.
"""
self._allow_zero_in_degree = set_value
def call(self, graph, feat):
r"""Compute Simplifying Graph Convolution layer.
r"""
Description
-----------
Compute Simplifying Graph Convolution layer.
Parameters
----------
......@@ -67,12 +138,31 @@ class SGConv(layers.Layer):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
Raises
------
DGLError
If there are 0-in-degree nodes in the input graph, it will raise DGLError
since no message will be passed to those nodes. This will cause invalid output.
The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
Notes
-----
If ``cache`` is se to True, ``feat`` and ``graph`` should not change during
If ``cache`` is set to True, ``feat`` and ``graph`` should not change during
training, or you will get wrong results.
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if self._cached_h is not None:
feat = self._cached_h
else:
......
......@@ -107,8 +107,9 @@ class HeteroGraphConv(layers.Layer):
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
if hasattr(v, '_allow_zero_in_degree'):
v._allow_zero_in_degree = True
set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
if callable(set_allow_zero_in_degree_fn):
set_allow_zero_in_degree_fn(True)
if isinstance(aggregate, str):
self.agg_fn = get_aggregate_fn(aggregate)
else:
......
......@@ -157,7 +157,7 @@ def test_tagconv():
assert h1.shape[-1] == 2
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -169,7 +169,7 @@ def test_gat_conv(g, idtype):
assert h.shape == (g.number_of_nodes(), 5, 20)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -254,7 +254,7 @@ def test_cheb_conv():
assert h1.shape == (20, 20)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_agnn_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -266,7 +266,7 @@ def test_agnn_conv(g, idtype):
assert h.shape == (g.number_of_nodes(), 10)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_agnn_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -359,7 +359,7 @@ def test_dense_sage_conv(idtype, g):
assert F.allclose(out_sage, out_dense_sage)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_edge_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -372,7 +372,7 @@ def test_edge_conv(g, idtype):
assert h1.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_edge_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -419,7 +419,7 @@ def test_gin_conv_bi(g, idtype, aggregator_type):
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gmm_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -431,7 +431,7 @@ def test_gmm_conv(g, idtype):
assert h1.shape == (g.number_of_nodes(), 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gmm_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......
......@@ -194,6 +194,7 @@ def test_rgcn():
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
rgc_basis_low.weight = rgc_basis.weight
rgc_basis_low.w_comp = rgc_basis.w_comp
rgc_basis_low.loop_weight = rgc_basis.loop_weight
h = tf.random.normal((100, I))
r = tf.constant(etype)
h_new = rgc_basis(g, h, r)
......@@ -205,6 +206,7 @@ def test_rgcn():
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
rgc_bdd_low.weight = rgc_bdd.weight
rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
h = tf.random.normal((100, I))
r = tf.constant(etype)
h_new = rgc_bdd(g, h, r)
......@@ -220,6 +222,7 @@ def test_rgcn():
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
rgc_basis_low.weight = rgc_basis.weight
rgc_basis_low.w_comp = rgc_basis.w_comp
rgc_basis_low.loop_weight = rgc_basis.loop_weight
h = tf.random.normal((100, I))
r = tf.constant(etype)
h_new = rgc_basis(g, h, r, norm)
......@@ -231,6 +234,7 @@ def test_rgcn():
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
rgc_bdd_low.weight = rgc_bdd.weight
rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
h = tf.random.normal((100, I))
r = tf.constant(etype)
h_new = rgc_bdd(g, h, r, norm)
......@@ -244,6 +248,7 @@ def test_rgcn():
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
rgc_basis_low.weight = rgc_basis.weight
rgc_basis_low.w_comp = rgc_basis.w_comp
rgc_basis_low.loop_weight = rgc_basis.loop_weight
h = tf.constant(np.random.randint(0, I, (100,))) * 1
r = tf.constant(etype) * 1
h_new = rgc_basis(g, h, r)
......@@ -253,7 +258,7 @@ def test_rgcn():
assert F.allclose(h_new, h_new_low)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite']))
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......@@ -263,7 +268,7 @@ def test_gat_conv(g, idtype):
assert h.shape == (g.number_of_nodes(), 4, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite']))
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
def test_gat_conv_bi(g, idtype):
g = g.astype(idtype).to(F.ctx())
ctx = F.ctx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment