"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F

from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape


class SAGEConv(nn.Module):
    r"""GraphSAGE layer from `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

        h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
        (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)

        h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})

    If a weight tensor on each edge is provided, the aggregation becomes:

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}
        \left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.

        SAGEConv can be applied on homogeneous graphs and unidirectional
        `bipartite graphs <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is applied to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature sizes take the
        same value.

        If the aggregator type is ``gcn``, the feature sizes of the source and
        destination nodes are required to be the same.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    aggregator_type : str
        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
    feat_drop : float
        Dropout rate on features, default: ``0``.
    bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
    norm : callable normalization function/layer or None, optional
        If not None, applies normalization to the updated node features.
        Default: ``None``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import SAGEConv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> conv = SAGEConv(10, 2, 'pool')
    >>> res = conv(g, feat)
    >>> res
    tensor([[-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099]], grad_fn=<AddBackward0>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.heterograph({('_U', '_E', '_V'): (u, v)})
    >>> u_fea = th.rand(2, 5)
    >>> v_fea = th.rand(4, 10)
    >>> conv = SAGEConv((5, 10), 2, 'mean')
    >>> res = conv(g, (u_fea, v_fea))
    >>> res
    tensor([[ 0.3163,  3.1166],
            [ 0.3866,  2.5398],
            [ 0.5873,  1.6597],
            [-0.2502,  2.8068]], grad_fn=<AddBackward0>)
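
    A minimal sketch of the weighted aggregation described above (the edge
    weights here are arbitrary placeholders, and the output is omitted):

    >>> # Case 3: Weighted aggregation on the homogeneous graph above
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> conv = SAGEConv(10, 2, 'mean')
    >>> e_w = th.ones(g.number_of_edges(), 1)  # one broadcastable weight per edge
    >>> res = conv(g, feat, edge_weight=e_w)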
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
        if aggregator_type not in valid_aggre_types:
            raise DGLError(
                'Invalid aggregator_type. Must be one of {}. '
                'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
            )

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation

        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)

        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)

        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        elif bias:
            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
        else:
            # keep ``self.bias`` defined so ``forward`` can test it
            self.register_buffer('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is reinitialized with its default PyTorch initialization.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
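        # initial hidden and cell states (h_0, c_0) for the LSTM, created on the
        # same device and with the same dtype as the incoming messages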
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`,
            where :math:`D_{in}` is the size of the input feature and :math:`N` is the
            number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edges. If given, each message is weighted by
            the corresponding edge weight before aggregation.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`,
            where :math:`N_{dst}` is the number of destination nodes in the input graph
            and :math:`D_{out}` is the size of the output feature.
        """
        with graph.local_scope():
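            # a pair of features means (source features, destination features),
            # as used for unidirectional bipartite graphs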
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            msg_fn = fn.copy_u('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            # Determine whether to apply linear transformation before message passing A(XW)
            lin_before_mp = self._in_src_feats > self._out_feats
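            # when the input dimension is larger than the output, transforming the
            # features first makes the aggregated messages smaller; for the linear
            # ``mean``/``gcn`` aggregators the result is the same either way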

            # Message Passing
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                graph.update_all(msg_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                else:
                    if graph.is_block:
                        graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
                    else:
                        graph.dstdata['h'] = graph.srcdata['h']
                graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                # divide by in-degree + 1 to average over the neighbors and the node itself
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'pool':
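                # max-pooling aggregator: each neighbor feature is passed through a
                # fully connected layer with ReLU, then an element-wise max is taken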
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, fn.max('m', 'neigh'))
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            elif self._aggre_type == 'lstm':
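                # LSTM aggregator: neighbor messages are fed to an LSTM as a sequence
                # and the final hidden state is used as the neighborhood representation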
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = h_neigh
                # add bias manually for GCN
                if self.bias is not None:
                    rst = rst + self.bias
            else:
                rst = self.fc_self(h_self) + h_neigh


            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst