"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F

from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape, dgl_warning


class SAGEConv(nn.Module):
    r"""

    Description
    -----------
    GraphSAGE layer from paper `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

        h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
        (h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)}) \right)

        h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})

    If a weight tensor on each edge is provided, the aggregation becomes:

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}
        \left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.

        SAGEConv can be applied to a homogeneous graph or a unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is applied to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature sizes take
        the same value.

        If the aggregator type is ``gcn``, the feature sizes of the source and
        destination nodes are required to be the same.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    aggregator_type : str
        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
    feat_drop : float
        Dropout rate on features, default: ``0``.
    bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
    norm : callable activation function/layer or None, optional
        If not None, applies normalization to the updated node features.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import SAGEConv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> conv = SAGEConv(10, 2, 'pool')
    >>> res = conv(g, feat)
    >>> res
    tensor([[-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099]], grad_fn=<AddBackward0>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.bipartite((u, v))
    >>> u_fea = th.rand(2, 5)
    >>> v_fea = th.rand(4, 10)
    >>> conv = SAGEConv((5, 10), 2, 'mean')
    >>> res = conv(g, (u_fea, v_fea))
    >>> res
    tensor([[ 0.3163,  3.1166],
            [ 0.3866,  2.5398],
            [ 0.5873,  1.6597],
            [-0.2502,  2.8068]], grad_fn=<AddBackward0>)
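
    A minimal sketch of the edge-weighted aggregation described above; the
    random ``edge_weight`` values below are for illustration only, so only the
    output shape is shown.

    >>> # Case 3: Edge-weighted aggregation
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> edge_weight = th.rand(g.number_of_edges())
    >>> conv = SAGEConv(10, 2, 'mean')
    >>> res = conv(g, feat, edge_weight=edge_weight)
    >>> res.shape
    torch.Size([6, 2])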
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        if bias:
            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
        else:
            self.register_buffer('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is reinitialized with its default PyTorch parameter initialization.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _compatibility_check(self):
        """Address the backward compatibility issue brought by #2747"""
        if not hasattr(self, 'bias'):
            dgl_warning("You are loading a GraphSAGE model trained with an old version of DGL. "
                        "DGL will automatically convert it to be compatible with the latest version.")
            bias = self.fc_neigh.bias
            self.fc_neigh.bias = None
            if hasattr(self, 'fc_self'):
                if bias is not None:
                    bias = bias + self.fc_self.bias
                    self.fc_self.bias = None
            self.bias = bias

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow; we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
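        # B = number of nodes in the current degree bucket, L = their shared
        # in-degree, D = source feature size.  Run the messages through the
        # LSTM with zero-initialized hidden and cell states (h_0, c_0).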
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`, where :math:`D_{in}` is the size of the input
            feature and :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional weight tensor on the edges. If given, each message is
            multiplied by the corresponding edge weight before aggregation.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`,
            where :math:`N_{dst}` is the number of destination nodes in the input graph
            and :math:`D_{out}` is the size of the output feature.
        """
        self._compatibility_check()
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            msg_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            h_self = feat_dst

            # Handle the case of graphs without edges
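            # (with no edges there are no messages, so the aggregated neighbor
            # feature is defined to be all-zero)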
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            # Determine whether to apply linear transformation before message passing A(XW)
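            # Projecting before aggregation is cheaper when it shrinks the
            # features (in_feats > out_feats), since smaller messages are passed.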
            lin_before_mp = self._in_src_feats > self._out_feats

            # Message Passing
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                graph.update_all(msg_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                else:
                    if graph.is_block:
                        graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
                    else:
                        graph.dstdata['h'] = graph.srcdata['h']
                graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                # divide by (in-degree + 1); the +1 accounts for the node's own feature added above
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, fn.max('m', 'neigh'))
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = h_neigh
            else:
                rst = self.fc_self(h_self) + h_neigh

            # bias term
            if self.bias is not None:
                rst = rst + self.bias

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst