data-process.rst 10.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
.. _guide_cn-data-pipeline-process:

4.3 处理数据
----------------

:ref:`(English Version) <guide-data-pipeline-process>`

用户可以在 ``process()`` 函数中实现数据处理。该函数假定原始数据已经位于 ``self.raw_dir`` 目录中。

图上的机器学习任务通常有三种类型:整图分类、节点分类和链接预测。本节将展示如何处理与这些任务相关的数据集。

本节重点介绍了处理图、特征和划分掩码的标准方法。用户指南将以内置数据集为例,并跳过从文件构建图的实现。
用户可以参考 :ref:`guide_cn-graph-external` 以查看如何从外部数据源构建图的完整指南。

处理整图分类数据集
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

整图分类数据集与用小批次训练的典型机器学习任务中的大多数数据集类似。
因此,需要将原始数据处理为 :class:`dgl.DGLGraph` 对象的列表和标签张量的列表。
此外,如果原始数据已被拆分为多个文件,则可以添加参数 ``split`` 以导入数据的特定部分。

下面以 :class:`~dgl.data.QM7bDataset` 为例:

.. code::

    from dgl.data import DGLDataset

    class QM7bDataset(DGLDataset):
        _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
               'datasets/qm7b.mat'
        _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
Mufei Li's avatar
Mufei Li committed
32

33
34
35
36
37
38
        def __init__(self, raw_dir=None, force_reload=False, verbose=False):
            super(QM7bDataset, self).__init__(name='qm7b',
                                              url=self._url,
                                              raw_dir=raw_dir,
                                              force_reload=force_reload,
                                              verbose=verbose)
Mufei Li's avatar
Mufei Li committed
39

40
41
42
43
        def process(self):
            mat_path = self.raw_path + '.mat'
            # 将数据处理为图列表和标签列表
            self.graphs, self.label = self._load_graph(mat_path)
Mufei Li's avatar
Mufei Li committed
44

45
46
        def __getitem__(self, idx):
            """ 通过idx获取对应的图和标签
Mufei Li's avatar
Mufei Li committed
47

48
49
50
51
            Parameters
            ----------
            idx : int
                Item index
Mufei Li's avatar
Mufei Li committed
52

53
54
55
56
57
            Returns
            -------
            (dgl.DGLGraph, Tensor)
            """
            return self.graphs[idx], self.label[idx]
Mufei Li's avatar
Mufei Li committed
58

59
60
61
62
63
64
65
66
67
68
        def __len__(self):
            """数据集中图的数量"""
            return len(self.graphs)

函数 ``process()`` 将原始数据处理为图列表和标签列表。用户必须实现 ``__getitem__(idx)`` 和  ``__len__()`` 以进行迭代。
DGL建议让 ``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图,标签)``。
用户可以参考 `QM7bDataset源代码  <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__
以获得 ``self._load_graph()`` 和 ``__getitem__`` 的详细信息。

用户还可以向类添加属性以指示一些有用的数据集信息。在 :class:`~dgl.data.QM7bDataset` 中,
Mufei Li's avatar
Mufei Li committed
69
用户可以添加属性 ``num_tasks`` 来指示此多任务数据集中的预测任务总数:
70
71
72
73

.. code::

    @property
Mufei Li's avatar
Mufei Li committed
74
    def num_tasks(self):
75
76
77
78
79
        """每个图的标签数,即预测任务数。"""
        return 14

在编写完这些代码之后,用户可以按如下所示的方式来使用 :class:`~dgl.data.QM7bDataset`:

Mufei Li's avatar
Mufei Li committed
80
.. code::
81
82
83
84

    import dgl
    import torch

85
    from dgl.dataloading import GraphDataLoader
Mufei Li's avatar
Mufei Li committed
86

87
88
    # 数据导入
    dataset = QM7bDataset()
Mufei Li's avatar
Mufei Li committed
89
90
    num_tasks = dataset.num_tasks

91
    # 创建 dataloaders
92
    dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
Mufei Li's avatar
Mufei Li committed
93

94
95
96
97
98
99
    # 训练
    for epoch in range(100):
        for g, labels in dataloader:
            # 用户自己的训练代码
            pass

100
训练整图分类模型的完整指南可以在 :ref:`guide_cn-training-graph-classification` 中找到。
101

102
有关整图分类数据集的更多示例,用户可以参考 :ref:`guide_cn-training-graph-classification`:
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

* :ref:`gindataset`

* :ref:`minigcdataset`

* :ref:`qm7bdata`

* :ref:`tudata`

处理节点分类数据集
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。
DGL建议使用节点掩码来指定数据集的划分。
本节以内置数据集 `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ 为例:

119
120
121
122
123
此外,DGL推荐重新排列图的节点/边,使得相邻节点/边的ID位于邻近区间内。这个过程
可以提高节点/边的邻居的局部性,为后续在图上进行的计算与分析的性能改善提供可能。
DGL提供了名为 :func:`dgl.reorder_graph` 的API用于此优化。更多细节,请参考
下面例子中的 ``process()`` 的部分。

124
125
126
.. code::

    from dgl.data import DGLBuiltinDataset
127
    from dgl.data.utils import _get_dgl_url
Mufei Li's avatar
Mufei Li committed
128

129
130
131
132
133
134
    class CitationGraphDataset(DGLBuiltinDataset):
        _urls = {
            'cora_v2' : 'dataset/cora_v2.zip',
            'citeseer' : 'dataset/citeseer.zip',
            'pubmed' : 'dataset/pubmed.zip',
        }
Mufei Li's avatar
Mufei Li committed
135

136
137
138
139
140
141
142
143
144
145
        def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
            assert name.lower() in ['cora', 'citeseer', 'pubmed']
            if name.lower() == 'cora':
                name = 'cora_v2'
            url = _get_dgl_url(self._urls[name])
            super(CitationGraphDataset, self).__init__(name,
                                                       url=url,
                                                       raw_dir=raw_dir,
                                                       force_reload=force_reload,
                                                       verbose=verbose)
Mufei Li's avatar
Mufei Li committed
146

147
148
149
        def process(self):
            # 跳过一些处理的代码
            # === 跳过数据处理 ===
Mufei Li's avatar
Mufei Li committed
150

151
152
153
154
            # 构建图
            g = dgl.graph(graph)

            # 划分掩码
155
156
157
            g.ndata['train_mask'] = train_mask
            g.ndata['val_mask'] = val_mask
            g.ndata['test_mask'] = test_mask
158
159
160
161
162
163
164

            # 节点的标签
            g.ndata['label'] = torch.tensor(labels)

            # 节点的特征
            g.ndata['feat'] = torch.tensor(_preprocess_features(features),
                                           dtype=F.data_type_dict['float32'])
Mufei Li's avatar
Mufei Li committed
165
            self._num_tasks = onehot_labels.shape[1]
166
            self._labels = labels
167
168
            # 重排图以获得更优的局部性
            self._g = dgl.reorder_graph(g)
Mufei Li's avatar
Mufei Li committed
169

170
171
172
        def __getitem__(self, idx):
            assert idx == 0, "这个数据集里只有一个图"
            return self._g
Mufei Li's avatar
Mufei Li committed
173

174
175
176
177
178
179
180
181
182
183
184
185
186
        def __len__(self):
            return 1

为简便起见,这里省略了 ``process()`` 中的一些代码,以突出展示用于处理节点分类数据集的关键部分:划分掩码。
节点特征和节点的标签被存储在 ``g.ndata`` 中。详细的实现请参考
`CitationGraphDataset源代码 <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ 。

请注意,这里 ``__getitem__(idx)`` 和 ``__len__()`` 的实现也发生了变化,
这是因为节点分类任务通常只用一个图。掩码在PyTorch和TensorFlow中是bool张量,在MXNet中是float张量。

下面中使用 :class:`dgl.data.CitationGraphDataset` 的子类 :class:`dgl.data.CiteseerGraphDataset`
来演示如何使用用于节点分类的数据集:

Mufei Li's avatar
Mufei Li committed
187
.. code::
188
189
190
191

    # 导入数据
    dataset = CiteseerGraphDataset(raw_dir='')
    graph = dataset[0]
Mufei Li's avatar
Mufei Li committed
192

193
194
195
196
    # 获取划分的掩码
    train_mask = graph.ndata['train_mask']
    val_mask = graph.ndata['val_mask']
    test_mask = graph.ndata['test_mask']
Mufei Li's avatar
Mufei Li committed
197

198
199
    # 获取节点特征
    feats = graph.ndata['feat']
Mufei Li's avatar
Mufei Li committed
200

201
202
203
    # 获取标签
    labels = graph.ndata['label']

204
:ref:`guide_cn-training-node-classification` 提供了训练节点分类模型的完整指南。
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248

有关节点分类数据集的更多示例,用户可以参考以下内置数据集:

* :ref:`citationdata`

* :ref:`corafulldata`

* :ref:`amazoncobuydata`

* :ref:`coauthordata`

* :ref:`karateclubdata`

* :ref:`ppidata`

* :ref:`redditdata`

* :ref:`sbmdata`

* :ref:`sstdata`

* :ref:`rdfdata`

处理链接预测数据集
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

链接预测数据集的处理与节点分类相似,数据集中通常只有一个图。

本节以内置的数据集 `KnowledgeGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__
为例,同时省略了详细的数据处理代码以突出展示处理链接预测数据集的关键部分:

.. code::

    # 创建链接预测数据集示例
    class KnowledgeGraphDataset(DGLBuiltinDataset):
        def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
            self._name = name
            self.reverse = reverse
            url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
            super(KnowledgeGraphDataset, self).__init__(name,
                                                        url=url,
                                                        raw_dir=raw_dir,
                                                        force_reload=force_reload,
                                                        verbose=verbose)
Mufei Li's avatar
Mufei Li committed
249

250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
        def process(self):
            # 跳过一些处理的代码
            # === 跳过数据处理 ===

            # 划分掩码
            g.edata['train_mask'] = train_mask
            g.edata['val_mask'] = val_mask
            g.edata['test_mask'] = test_mask

            # 边类型
            g.edata['etype'] = etype

            # 节点类型
            g.ndata['ntype'] = ntype
            self._g = g
Mufei Li's avatar
Mufei Li committed
265

266
267
268
        def __getitem__(self, idx):
            assert idx == 0, "这个数据集只有一个图"
            return self._g
Mufei Li's avatar
Mufei Li committed
269

270
271
272
273
274
275
276
277
        def __len__(self):
            return 1


如代码所示,图的 ``edata`` 存储了划分掩码。在
`KnowledgeGraphDataset 源代码 <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__
中可以查看完整的代码。下面使用 ``KnowledgeGraphDataset``的子类 :class:`dgl.data.FB15k237Dataset` 来做演示如何使用用于链路预测的数据集:

Mufei Li's avatar
Mufei Li committed
278
.. code::
279
280
281
282
283
284

    from dgl.data import FB15k237Dataset

    # 导入数据
    dataset = FB15k237Dataset()
    graph = dataset[0]
Mufei Li's avatar
Mufei Li committed
285

286
287
    # 获取训练集掩码
    train_mask = graph.edata['train_mask']
288
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
289
290
291
292
293
    src, dst = graph.edges(train_idx)

    # 获取训练集中的边类型
    rel = graph.edata['etype'][train_idx]

294
有关训练链接预测模型的完整指南,请参见 :ref:`guide_cn-training-link-prediction`。
295
296
297
298
299
300

有关链接预测数据集的更多示例,请参考DGL的内置数据集:

* :ref:`kgdata`

* :ref:`bitcoinotcdata`