nn-construction.rst 3.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
.. _guide_cn-nn-construction:

3.1 DGL NN模块的构造函数
-----------------------------

:ref:`(English Version) <guide-nn-construction>`

构造函数完成以下几个任务:

1. 设置选项。
2. 注册可学习的参数或者子模块。
3. 初始化参数。

.. code::

    import torch.nn as nn

    from dgl.utils import expand_as_pair

    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()

            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation

在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。
对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(``self._aggre_type``)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。
常用的聚合类型包括 ``mean``、 ``sum``、 ``max`` 和 ``min``。一些模块可能会使用更加复杂的聚合函数,比如 ``lstm``。

上面代码里的 ``norm`` 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:
:math:`h_v = h_v / \lVert h_v \rVert_2`。

.. code::

            # 聚合类型:mean、max_pool、lstm、gcn
            if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
                raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
            if aggregator_type == 'max_pool':
                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
            if aggregator_type == 'lstm':
                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
            if aggregator_type in ['mean', 'max_pool', 'lstm']:
                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
            self.reset_parameters()

注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 ``nn.Linear``、 ``nn.LSTM`` 等。
构造函数的最后调用了 ``reset_parameters()`` 进行权重初始化。

.. code::

        def reset_parameters(self):
            """重新初始化可学习的参数"""
            gain = nn.init.calculate_gain('relu')
            if self._aggre_type == 'max_pool':
                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
            if self._aggre_type == 'lstm':
                self.lstm.reset_parameters()
            if self._aggre_type != 'gcn':
                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)