"vscode:/vscode.git/clone" did not exist on "38cf0dc1c52138987e6e66295c5a2d192a6914bd"
minibatch-nn.rst 8.42 KB
Newer Older
1
2
3
4
5
.. _guide-minibatch-custom-gnn-module:

6.5 Implementing Custom GNN Module for Mini-batch Training
-------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-custom-gnn-module>`

.. note::

   :doc:`This tutorial <tutorials/large/L4_message_passing>` has similar
   content to this section for the homogeneous graph case.


If you are familiar with how to write a custom GNN module that updates
the entire graph, whether homogeneous or heterogeneous (see
:ref:`guide-nn`), the code for computing on MFGs is similar, except that
the nodes are divided into input nodes and output nodes.

For example, consider the following custom graph convolution module.
Note that it is not necessarily the most efficient implementation; it
only serves as an example of what a custom GNN module could look like.

.. code:: python

    import torch
    import torch.nn as nn
    import dgl.function as fn

    class CustomGraphConv(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.W = nn.Linear(in_feats * 2, out_feats)
    
        def forward(self, g, h):
            with g.local_scope():
                g.ndata['h'] = h
                g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
                return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
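
For full-graph computation the module is called with the graph and its node
feature tensor, for instance as in the minimal sketch below; ``g`` and its
feature ``'feat'`` are placeholders, not part of the original example.

.. code:: python

    # A minimal sketch; g is a homogeneous DGLGraph whose node features
    # live in g.ndata['feat'].
    conv = CustomGraphConv(g.ndata['feat'].shape[1], 16)
    h_new = conv(g, g.ndata['feat'])   # one new feature row per node of g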

If you have a custom message passing NN module for the full graph and
you would like to make it work for MFGs, you only need to rewrite the
forward function as follows. The corresponding statements from the
full-graph implementation are kept as comments so you can compare them
with the new statements.

.. code:: python

    class CustomGraphConv(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.W = nn.Linear(in_feats * 2, out_feats)
    
        # h is the feature tensor of the input nodes; the features of the
        # output nodes are its first few rows, sliced out below.
        # def forward(self, g, h):
        def forward(self, block, h):
            # with g.local_scope():
            with block.local_scope():
                # g.ndata['h'] = h
                h_src = h
                h_dst = h[:block.number_of_dst_nodes()]
                block.srcdata['h'] = h_src
                block.dstdata['h'] = h_dst
    
                # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
                block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
    
                # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
                return self.W(torch.cat(
                    [block.dstdata['h'], block.dstdata['h_neigh']], 1))

In general, you need to do the following to make your NN module work for
MFGs (a minimal usage sketch follows the list below).

-  Obtain the features for output nodes from the input features by
   slicing the first few rows. The number of rows can be obtained by
   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>`.
-  Replace
   :attr:`g.ndata <dgl.DGLGraph.ndata>` with either
   :attr:`block.srcdata <dgl.DGLGraph.srcdata>` for features on input nodes or
   :attr:`block.dstdata <dgl.DGLGraph.dstdata>` for features on output nodes, if
   the original graph has only one node type.
-  Replace
   :attr:`g.nodes <dgl.DGLGraph.nodes>` with either
   :attr:`block.srcnodes <dgl.DGLGraph.srcnodes>` for features on input nodes or
   :attr:`block.dstnodes <dgl.DGLGraph.dstnodes>` for features on output nodes,
   if the original graph has multiple node types.
-  Replace
   :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` with either
   :meth:`block.number_of_src_nodes <dgl.DGLGraph.number_of_src_nodes>` or
   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` for the number of
   input nodes or output nodes respectively.
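
For example, here is a minimal sketch of how the block-based
``CustomGraphConv`` above could be driven by a neighbor-sampling data loader.
The graph ``g`` (with node features in ``g.ndata['feat']``), the seed node IDs
``train_nids``, and the feature sizes are placeholders assumed to be defined
elsewhere.

.. code:: python

    # A minimal sketch; g, train_nids, in_feats and out_feats are
    # illustrative placeholders, not part of the original example.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    dataloader = dgl.dataloading.NodeDataLoader(
        g, train_nids, sampler, batch_size=1024)

    conv = CustomGraphConv(in_feats, out_feats)
    for input_nodes, output_nodes, blocks in dataloader:
        block = blocks[0]
        # Features of the input nodes; the module slices out the output node
        # features internally with block.number_of_dst_nodes().
        h = g.ndata['feat'][input_nodes]
        h_out = conv(block, h)   # one row per output node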

Heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~

For heterogeneous graphs, the way of writing custom GNN modules is
similar. For instance, consider the following module that works on the
full graph.

.. code:: python

    class CustomHeteroGraphConv(nn.Module):
        def __init__(self, g, in_feats, out_feats):
            super().__init__()
            self.Ws = nn.ModuleDict()
            self.Vs = nn.ModuleDict()
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                # nn.ModuleDict keys must be strings, so the canonical edge
                # type tuple is stringified.
                self.Ws[str(etype)] = nn.Linear(in_feats[utype], out_feats[vtype])
            for ntype in g.ntypes:
                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])

        def forward(self, g, h):
            with g.local_scope():
                for ntype in g.ntypes:
                    g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                    g.nodes[ntype].data['h_src'] = h[ntype]
                for etype in g.canonical_etypes:
                    utype, _, vtype = etype
                    g.update_all(
                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                        etype=etype)
                    g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \
                        self.Ws[str(etype)](g.nodes[vtype].data['h_neigh'])
                return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}

For ``CustomHeteroGraphConv``, the principle is to replace ``g.nodes``
with ``g.srcnodes`` or ``g.dstnodes``, depending on whether the features
are for input nodes or output nodes.

.. code:: python

    class CustomHeteroGraphConv(nn.Module):
        def __init__(self, g, in_feats, out_feats):
            super().__init__()
            self.Ws = nn.ModuleDict()
            self.Vs = nn.ModuleDict()
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                # nn.ModuleDict keys must be strings, so the canonical edge
                # type tuple is stringified.
                self.Ws[str(etype)] = nn.Linear(in_feats[utype], out_feats[vtype])
            for ntype in g.ntypes:
                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])

        def forward(self, g, h):
            with g.local_scope():
                for ntype in g.ntypes:
                    # h[ntype] is a pair of (input node, output node) features.
                    h_src, h_dst = h[ntype]
                    g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h_dst)
                    g.srcnodes[ntype].data['h_src'] = h_src
                for etype in g.canonical_etypes:
                    utype, _, vtype = etype
                    g.update_all(
                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                        etype=etype)
                    g.dstnodes[vtype].data['h_dst'] = \
                        g.dstnodes[vtype].data['h_dst'] + \
                        self.Ws[str(etype)](g.dstnodes[vtype].data['h_neigh'])
                return {ntype: g.dstnodes[ntype].data['h_dst']
                        for ntype in g.ntypes}
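
As a rough usage sketch (not part of the original example), the per-node-type
feature pairs can be built from a heterogeneous block produced by a neighbor
sampler; ``hetero_graph``, ``node_features``, ``train_nids``, ``in_feats`` and
``out_feats`` below are illustrative placeholders.

.. code:: python

    # A minimal sketch with made-up names; hetero_graph is the full
    # heterogeneous graph, node_features and train_nids are dicts keyed by
    # node type.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    dataloader = dgl.dataloading.NodeDataLoader(
        hetero_graph, train_nids, sampler, batch_size=1024)

    conv = CustomHeteroGraphConv(hetero_graph, in_feats, out_feats)
    for input_nodes, output_nodes, blocks in dataloader:
        block = blocks[0]
        h = {}
        for ntype in block.srctypes:
            # Input node features of this type; the output nodes of the same
            # type come first among the input nodes.
            h_src = node_features[ntype][input_nodes[ntype]]
            h_dst = h_src[:block.number_of_dst_nodes(ntype)]
            h[ntype] = (h_src, h_dst)
        h_out = conv(block, h)   # dict of per-type output representations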

Writing modules that work on homogeneous graphs, bipartite graphs, and MFGs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All message passing modules in DGL work on homogeneous graphs,
unidirectional bipartite graphs (which have two node types and one edge
type), and MFGs with one edge type. Essentially, the input graph and
input feature of a builtin DGL neural network module must satisfy one of
the following cases.

-  If the input feature is a pair of tensors, then the input graph must
   be unidirectional bipartite.
-  If the input feature is a single tensor and the input graph is an
   MFG, DGL will automatically set the features of the output nodes to
   the first few rows of the input node features.
-  If the input feature is a single tensor and the input graph is not an
   MFG, then the input graph must be homogeneous.

For example, the following is simplified from the PyTorch implementation
of :class:`dgl.nn.pytorch.SAGEConv` (also available in MXNet and TensorFlow),
with normalization removed and only one aggregator type kept.

.. code:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import dgl.function as fn

    class SAGEConv(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.W = nn.Linear(in_feats * 2, out_feats)
    
        def forward(self, g, h):
            if isinstance(h, tuple):
                h_src, h_dst = h
            elif g.is_block:
                h_src = h
                h_dst = h[:g.number_of_dst_nodes()]
            else:
                h_src = h_dst = h
                 
            g.srcdata['h'] = h_src
            g.dstdata['h'] = h_dst
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
            return F.relu(
                self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))
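
To illustrate the three cases above, the following sketch (with made-up toy
graphs and random features; none of it comes from the original example) calls
the same ``SAGEConv`` instance on a homogeneous graph, on an MFG, and on a
unidirectional bipartite graph.

.. code:: python

    import dgl
    import torch

    conv = SAGEConv(10, 5)

    # Case: homogeneous graph with a single feature tensor.
    g = dgl.rand_graph(30, 100)                   # 30 nodes, 100 random edges
    h = torch.randn(30, 10)
    out_full = conv(g, h)                         # shape (30, 5)

    # Case: an MFG with a single feature tensor for the input nodes; the
    # output node features are taken from its first rows.
    seeds = torch.arange(5)
    block = dgl.to_block(dgl.in_subgraph(g, seeds), seeds)
    out_block = conv(block, h[block.srcdata[dgl.NID]])   # shape (5, 5)

    # Case: unidirectional bipartite graph with a pair of feature tensors.
    bg = dgl.heterograph({('user', 'clicks', 'item'):
                          (torch.tensor([0, 1]), torch.tensor([1, 2]))})
    h_user = torch.randn(bg.num_nodes('user'), 10)
    h_item = torch.randn(bg.num_nodes('item'), 10)
    out_bipartite = conv(bg, (h_user, h_item))    # one row per item node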

:ref:`guide-nn` also provides a walkthrough on :class:`dgl.nn.pytorch.SAGEConv`,
which works on unidirectional bipartite graphs, homogeneous graphs, and MFGs.