"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F

from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape


class SAGEConv(nn.Module):
    r"""

    Description
    -----------
    GraphSAGE layer from paper `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

        h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
        (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)

        h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})

    If a weight tensor on each edge is provided, the aggregation becomes:

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.

        SAGEConv can be applied on a homogeneous graph and a unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is applied to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature sizes take
        the same value.

        If the aggregator type is ``gcn``, the feature sizes of the source and
        destination nodes are required to be the same.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    feat_drop : float
        Dropout rate on features, default: ``0``.
    aggregator_type : str
        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
    bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
    norm : callable activation function/layer or None, optional
        If not None, applies normalization to the updated node features.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import SAGEConv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> conv = SAGEConv(10, 2, 'pool')
    >>> res = conv(g, feat)
    >>> res
    tensor([[-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099]], grad_fn=<AddBackward0>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.bipartite((u, v))
    >>> u_fea = th.rand(2, 5)
    >>> v_fea = th.rand(4, 10)
    >>> conv = SAGEConv((5, 10), 2, 'mean')
    >>> res = conv(g, (u_fea, v_fea))
    >>> res
    tensor([[ 0.3163,  3.1166],
            [ 0.3866,  2.5398],
            [ 0.5873,  1.6597],
            [-0.2502,  2.8068]], grad_fn=<AddBackward0>)
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is reinitialized using its own ``reset_parameters`` method.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
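        # Feed the neighbor sequence to the LSTM; the final hidden state (1, B, D)
        # summarizes each node's neighborhood.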
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional per-edge weight tensor. If given, each neighbor message is
            scaled by its edge weight before aggregation.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is the size of the output feature.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
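                # Scale each neighbor message by its edge weight during aggregation.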
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst     # same as above if homogeneous
                graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
                # divide by (in-degree + 1); the +1 accounts for the destination node's own feature
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'pool':
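                # Max-pool aggregator: transform each neighbor with a shared FC + ReLU,
                # then take an element-wise max over the neighborhood.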
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
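                # LSTM aggregator: run the shared LSTM over each node's neighbor
                # messages (see _lstm_reducer).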
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst