"""MXNet modules for graph attention networks(GAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Identity

from .... import function as fn
from ....base import DGLError
from ...functional import edge_softmax
from ....utils import expand_as_pair

#pylint: enable=W0235
class GATConv(nn.Block):
    r"""Graph attention layer from `Graph Attention Network
    <https://arxiv.org/pdf/1710.10903.pdf>`__

    .. math::
        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij} W^{(l)} h_j^{(l)}

    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
    node :math:`j`:

    .. math::
        \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})

        e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.
        GATConv can be applied on homogeneous graph and unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.  If
        a scalar is given, the source and destination node feature size would take the
        same value.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    num_heads : int
        Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on feature. Defaults: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weight. Defaults: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Defaults: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Defaults: ``False``.
    activation : callable activation function/layer or None, optional.
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications,
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in the input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Defaults: ``False``.

    Note
    ----
    Zero in-degree nodes will lead to invalid output value. This is because no message
    will be passed to those nodes, so the aggregation function will be applied on empty input.
    A common practice to avoid this is to add a self-loop for each node in the graph if
    it is homogeneous, which can be achieved by:

    >>> g = ... # a DGLGraph
    >>> g = dgl.add_self_loop(g)

    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graphs,
    since the edge type can not be decided for self-loop edges. Set ``allow_zero_in_degree``
    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
    A common practice to handle this is to filter out the nodes with zero in-degree when using
    the output after conv, as sketched below.
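
    A minimal sketch of such filtering (illustrative only; it assumes a homogeneous
    ``DGLGraph`` ``g``, the layer output ``res``, and that ``mxnet`` is imported as ``mx``):

    >>> import numpy as np
    >>> valid_idx = mx.nd.array(np.nonzero(g.in_degrees().asnumpy() > 0)[0])
    >>> res_valid = mx.nd.take(res, valid_idx)  # keep only nodes that received messages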

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import mxnet as mx
    >>> from mxnet import gluon
    >>> from dgl.nn import GATConv
    >>>
    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = mx.nd.ones((6, 10))
    >>> gatconv = GATConv(10, 2, num_heads=3)
    >>> gatconv.initialize(ctx=mx.cpu(0))
    >>> res = gatconv(g, feat)
    >>> res
    [[[ 0.32368395 -0.10501936]
    [ 1.0839728   0.92690575]
    [-0.54581136 -0.84279203]]
    [[ 0.32368395 -0.10501936]
    [ 1.0839728   0.92690575]
    [-0.54581136 -0.84279203]]
    [[ 0.32368395 -0.10501936]
    [ 1.0839728   0.92690575]
    [-0.54581136 -0.84279203]]
    [[ 0.32368395 -0.10501937]
    [ 1.0839728   0.9269058 ]
    [-0.5458114  -0.8427921 ]]
    [[ 0.32368395 -0.10501936]
    [ 1.0839728   0.92690575]
    [-0.54581136 -0.84279203]]
    [[ 0.32368395 -0.10501936]
    [ 1.0839728   0.92690575]
    [-0.54581136 -0.84279203]]]
    <NDArray 6x3x2 @cpu(0)>

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
    >>> u_feat = mx.nd.random.randn(2, 5)
    >>> v_feat = mx.nd.random.randn(4, 10)
    >>> gatconv = GATConv((5,10), 2, 3)
    >>> gatconv.initialize(ctx=mx.cpu(0))
    >>> res = gatconv(g, (u_feat, v_feat))
    >>> res
    [[[-1.01624     1.8138596 ]
    [ 1.2322129  -0.8410206 ]
    [-1.9325689   1.3824553 ]]
    [[ 0.9915016  -1.6564168 ]
    [-0.32610354  0.42505783]
    [ 1.5278397  -0.92114615]]
    [[-0.32592064  0.62067866]
    [ 0.6162219  -0.3405491 ]
    [-1.356375    0.9988818 ]]
    [[-1.01624     1.8138596 ]
    [ 1.2322129  -0.8410206 ]
    [-1.9325689   1.3824553 ]]]
    <NDArray 4x3x2 @cpu(0)>
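
    The per-edge attention scores can be retrieved as well (an illustrative addition;
    the shape follows the ``get_attention`` description in ``forward``):

    >>> res, attn = gatconv(g, (u_feat, v_feat), get_attention=True)
    >>> attn.shape
    (5, 3, 1)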
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False):
        super(GATConv, self).__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        with self.name_scope():
            if isinstance(in_feats, tuple):
                self.fc_src = nn.Dense(out_feats * num_heads, use_bias=False,
                                       weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
                                       in_units=self._in_src_feats)
                self.fc_dst = nn.Dense(out_feats * num_heads, use_bias=False,
                                       weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
                                       in_units=self._in_dst_feats)
            else:
                self.fc = nn.Dense(out_feats * num_heads, use_bias=False,
                                   weight_initializer=mx.init.Xavier(magnitude=math.sqrt(2.0)),
                                   in_units=in_feats)
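            # attn_l / attn_r hold the per-head attention vectors a_l and a_r
            # (the two halves of the paper's vector a); the shape
            # (1, num_heads, out_feats) lets them broadcast against the
            # projected node features.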
            self.attn_l = self.params.get('attn_l',
                                          shape=(1, num_heads, out_feats),
                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
            self.attn_r = self.params.get('attn_r',
                                          shape=(1, num_heads, out_feats),
                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
            self.feat_drop = nn.Dropout(feat_drop)
            self.attn_drop = nn.Dropout(attn_drop)
            self.leaky_relu = nn.LeakyReLU(negative_slope)
            if residual:
                if self._in_dst_feats != out_feats:
                    self.res_fc = nn.Dense(out_feats * num_heads, use_bias=False,
                                           weight_initializer=mx.init.Xavier(
                                               magnitude=math.sqrt(2.0)),
                                           in_units=self._in_dst_feats)
                else:
                    self.res_fc = Identity()
            else:
                self.res_fc = None
            self.activation = activation

    def set_allow_zero_in_degree(self, set_value):
        r"""

        Description
        -----------
        Set allow_zero_in_degree flag.

        Parameters
        ----------
        set_value : bool
            The value to be set to the flag.
        """
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, get_attention=False):
        r"""

        Description
        -----------
        Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray or pair of mxnet.NDArray
            If an mxnet.NDArray is given, it is the input feature of shape :math:`(N, *, D_{in})`,
            where :math:`D_{in}` is the size of the input feature and :math:`N` is the number of
            nodes. If a pair of mxnet.NDArray is given, the pair must contain two tensors of shape
            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
        get_attention : bool, optional
            Whether to return the attention values. Defaults to ``False``.

        Returns
        -------
        mxnet.NDArray
            The output feature of shape :math:`(N, *, H, D_{out})`, where :math:`H`
            is the number of heads and :math:`D_{out}` is the size of the output feature.
        mxnet.NDArray, optional
            The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
            edges. This is returned only when :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting the ``allow_zero_in_degree`` parameter to ``True``.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if graph.in_degrees().min() == 0:
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                src_prefix_shape = feat[0].shape[:-1]
                dst_prefix_shape = feat[1].shape[:-1]
                feat_dim = feat[0].shape[-1]
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
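                # Modules constructed with a scalar ``in_feats`` only define ``self.fc``;
                # reuse that single projection for both source and destination features.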
                if not hasattr(self, 'fc_src'):
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src = self.fc_src(h_src.reshape(-1, feat_dim)).reshape(
                    *src_prefix_shape, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst.reshape(-1, feat_dim)).reshape(
                    *dst_prefix_shape, self._num_heads, self._out_feats)
            else:
                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
                feat_dim = feat.shape[-1]
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src.reshape(-1, feat_dim)).reshape(
                    *src_prefix_shape, self._num_heads, self._out_feats)
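                # In a block (message flow graph), the destination nodes are the first
                # number_of_dst_nodes() entries among the source nodes, so slicing
                # recovers their features.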
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
                    h_dst = h_dst[:graph.number_of_dst_nodes()]
                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
            # NOTE: the GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition"; the two approaches are mathematically equivalent:
            # we decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # the addition can be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and reduces the memory footprint.
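            # Shape note: feat_src is (N_src, ..., num_heads, out_feats) and attn_l is
            # (1, num_heads, out_feats), so `el` below is (N_src, ..., num_heads, 1),
            # i.e. one scalar score a_l^T (W h_i) per node and head; likewise for `er`.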
            el = (feat_src * self.attn_l.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
            er = (feat_dst * self.attn_r.data(feat_src.context)).sum(axis=-1).expand_dims(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
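            # rst: (N_dst, ..., num_heads, out_feats), the attention-weighted sum of
            # projected neighbor features for each destination node.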
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst.reshape(-1, self._in_dst_feats)).reshape(
                    *dst_prefix_shape, -1, self._out_feats)
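                # resval is (*dst_prefix_shape, num_heads, out_feats) when res_fc is a
                # Dense projection, or (*dst_prefix_shape, 1, out_feats) for Identity,
                # which broadcasts across the head dimension in the sum below.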
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata['a']
            else:
                return rst