"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
from torch import nn
from torch.nn import functional as F

from .... import function as fn


class SAGEConv(nn.Module):
    r"""GraphSAGE layer from paper `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate}
        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

        h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat}
        (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) + b \right)

        h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{(l+1)})

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size.

        If the layer applies to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature sizes take
        the same value.

        If the aggregator type is ``gcn``, the feature sizes of the source and
        destination nodes are required to be the same.
    out_feats : int
        Output feature size.
    aggregator_type : str
        Aggregator type to use (``mean``, ``gcn``, ``pool``, or ``lstm``).
    feat_drop : float
        Dropout rate on features. Default: ``0``.
    bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
    norm : callable normalization function/layer or None, optional
        If not None, applies normalization to the updated node features.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
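
    Examples
    --------
    A minimal usage sketch; the graph construction, feature sizes, and
    aggregator choice below are illustrative assumptions, not requirements
    of this module:

    >>> import dgl
    >>> import torch
    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(4)
    >>> g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])  # a 4-node directed cycle
    >>> feat = torch.ones(4, 10)                 # N = 4 nodes, D_in = 10
    >>> conv = SAGEConv(10, 2, 'mean')           # in_feats=10, out_feats=2
    >>> res = conv(g, feat)
    >>> res.shape
    torch.Size([4, 2])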
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        if isinstance(in_feats, tuple):
            self._in_src_feats = in_feats[0]
            self._in_dst_feats = in_feats[1]
        elif isinstance(in_feats, Integral):
            self._in_src_feats = self._in_dst_feats = in_feats
        else:
            raise TypeError('in_feats must be either int or pair of ints')

        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m']  # (B, L, D): B bucketed nodes, L messages each
        batch_size = m.shape[0]
        # Zero-initialized (h_0, c_0) states for the LSTM.
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)  # final hidden state summarizes the neighborhood
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat):
        r"""Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it is the input feature of shape :math:`(N, D_{in})`,
            where :math:`D_{in}` is the size of the input feature and :math:`N` is the
            number of nodes. If a pair of torch.Tensor is given, the pair must contain two
            tensors of shape :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`
            (a usage sketch follows the Returns section below).

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})`, where :math:`D_{out}`
            is the size of the output feature.
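
        A hedged sketch of the pair-of-tensor form on a unidirectional
        bipartite graph; the ``dgl.bipartite`` construction and all sizes
        below are illustrative assumptions:

        >>> import dgl
        >>> import torch
        >>> g = dgl.bipartite([(0, 0), (1, 0), (2, 1)])  # 3 src nodes, 2 dst nodes
        >>> feat_src = torch.ones(3, 5)                  # (N_in, D_in_src)
        >>> feat_dst = torch.ones(2, 10)                 # (N_out, D_in_dst)
        >>> conv = SAGEConv((5, 10), 2, 'mean')
        >>> res = conv(g, (feat_src, feat_dst))
        >>> res.shape
        torch.Size([2, 2])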
        """
        graph = graph.local_var()

        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)

        h_self = feat_dst

        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same as above if homogeneous
            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
            # divide by (in-degree + 1): the destination node counts as a neighbor
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'lstm':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst