minibatch-nn.rst 8.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
.. _guide_cn-minibatch-custom-gnn-module:

6.5 为小批次训练实现定制化的GNN模块
-------------------------------------------------------------

:ref:`(English Version) <guide-minibatch-custom-gnn-module>`

如果用户熟悉如何定制用于更新整个同构图或异构图的GNN模块(参见
:ref:`guide_cn-nn`),那么在块上计算的代码也是类似的,区别只在于节点被划分为输入节点和输出节点。

以下面的自定义图卷积模块代码为例。注意,该代码并不一定是最高效的实现,
此处只是将其作为自定义GNN模块的一个示例。

.. code:: python

    class CustomGraphConv(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.W = nn.Linear(in_feats * 2, out_feats)
    
        def forward(self, g, h):
            with g.local_scope():
                g.ndata['h'] = h
                g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
                return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))

如果用户已有一个用于整个图的自定义消息传递模块,并且想将其用于块,则只需要按照如下的方法重写forward函数。
注意,以下代码在注释里保留了整图实现的语句,用户可以将用于块的语句和原先用于整图的语句进行比较。

.. code:: python

    class CustomGraphConv(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.W = nn.Linear(in_feats * 2, out_feats)

        # h现在是输入和输出节点的特征张量对,而不是一个单独的特征张量

        # def forward(self, g, h):
        def forward(self, block, h):
            # with g.local_scope():
            with block.local_scope():
                # g.ndata['h'] = h
                h_src = h
                h_dst = h[:block.number_of_dst_nodes()]
                block.srcdata['h'] = h_src
                block.dstdata['h'] = h_dst
    
                # g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
                block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
    
                # return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
                return self.W(torch.cat(
                    [block.dstdata['h'], block.dstdata['h_neigh']], 1))

通常,需要对用于整图的GNN模块进行如下调整以将其用于块作为输入的情况:

-  切片取输入特征的前几行,得到输出节点的特征。切片行数可以通过
59
60
61
62
63
   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` 获得。
-  如果原图只包含一种节点类型,对输入节点特征,将 :attr:`g.ndata <dgl.DGLGraph.ndata>` 替换为
   :attr:`block.srcdata <dgl.DGLGraph.srcdata>`;对于输出节点特征,将
   :attr:`g.ndata <dgl.DGLGraph.ndata>`  替换为
   :attr:`block.dstdata <dgl.DGLGraph.dstdata>`。
64
-  如果原图包含多种节点类型,对于输入节点特征,将
65
66
67
68
   :attr:`g.nodes <dgl.DGLGraph.nodes>` 替换为
   :attr:`block.srcnodes <dgl.DGLGraph.srcnodes>`;对于输出节点特征,将
   :attr:`g.nodes <dgl.DGLGraph.nodes>` 替换为
   :attr:`block.dstnodes <dgl.DGLGraph.dstnodes>`。
69
-  对于输入节点数量,将 :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` 替换为
70
   :meth:`block.number_of_src_nodes <dgl.DGLGraph.number_of_src_nodes>` ;
71
   对于输出节点数量,将 :meth:`g.num_nodes <dgl.DGLGraph.num_nodes>` 替换为
72
   :meth:`block.number_of_dst_nodes <dgl.DGLGraph.number_of_dst_nodes>` 。
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175

异构图上的模型定制
~~~~~~~~~~~~~~~~~~~~

为异构图修改GNN模块的方法是类似的。例如,以下面用于全图的GNN模块为例:

.. code:: python

    class CustomHeteroGraphConv(nn.Module):
        def __init__(self, g, in_feats, out_feats):
            super().__init__()
            self.Ws = nn.ModuleDict()
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
            for ntype in g.ntypes:
                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])
    
        def forward(self, g, h):
            with g.local_scope():
                for ntype in g.ntypes:
                    g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                    g.nodes[ntype].data['h_src'] = h[ntype]
                for etype in g.canonical_etypes:
                    utype, _, vtype = etype
                    g.update_all(
                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                        etype=etype)
                    g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \
                        self.Ws[etype](g.nodes[vtype].data['h_neigh'])
                return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}

对于 ``CustomHeteroGraphConv``,原则是将 ``g.nodes`` 替换为 ``g.srcnodes`` 或
``g.dstnodes`` (根据需要输入还是输出节点的特征来选择)。

.. code:: python

    class CustomHeteroGraphConv(nn.Module):
        def __init__(self, g, in_feats, out_feats):
            super().__init__()
            self.Ws = nn.ModuleDict()
            for etype in g.canonical_etypes:
                utype, _, vtype = etype
                self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
            for ntype in g.ntypes:
                self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])
    
        def forward(self, g, h):
            with g.local_scope():
                for ntype in g.ntypes:
                    h_src, h_dst = h[ntype]
                    g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
                    g.srcnodes[ntype].data['h_src'] = h[ntype]
                for etype in g.canonical_etypes:
                    utype, _, vtype = etype
                    g.update_all(
                        fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
                        etype=etype)
                    g.dstnodes[vtype].data['h_dst'] = \
                        g.dstnodes[vtype].data['h_dst'] + \
                        self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])
                return {ntype: g.dstnodes[ntype].data['h_dst']
                        for ntype in g.ntypes}

实现能够处理同构图、二分图和块的模块
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DGL中所有的消息传递模块(参见 :ref:`apinn`)都能够处理同构图、
单向二分图(包含两种节点类型和一种边类型)和包含一种边类型的块。
本质上,内置的DGL神经网络模块的输入图及特征必须满足下列情况之一:

-  如果输入特征是一个张量对,则输入图必须是一个单向二分图
-  如果输入特征是一个单独的张量且输入图是一个块,则DGL会自动将输入节点特征前一部分设为输出节点的特征。
-  如果输入特征是一个单独的张量且输入图不是块,则输入图必须是同构图。

例如,下面的代码是 :class:`dgl.nn.pytorch.SAGEConv` 的简化版(DGL同样支持它在MXNet和TensorFlow后端里的实现)。
代码里移除了归一化,且只考虑平均聚合函数的情况。

.. code:: python

    import dgl.function as fn
    class SAGEConv(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.W = nn.Linear(in_feats * 2, out_feats)
    
        def forward(self, g, h):
            if isinstance(h, tuple):
                h_src, h_dst = h
            elif g.is_block:
                h_src = h
                h_dst = h[:g.number_of_dst_nodes()]
            else:
                h_src = h_dst = h
                 
            g.srcdata['h'] = h_src
            g.dstdata['h'] = h_dst
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
            return F.relu(
                self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))

:ref:`guide_cn-nn` 提供了对 :class:`dgl.nn.pytorch.SAGEConv` 代码的详细解读,
其适用于单向二分图、同构图和块。