node_classification.ipynb 12.3 KB
Newer Older
1
2
3
4
{
  "cells": [
    {
      "cell_type": "markdown",
5
6
7
      "metadata": {
        "id": "OxbY2KlG4ZfJ"
      },
8
9
10
11
12
13
14
15
16
17
18
19
20
      "source": [
        "# Node Classification\n",
        "This tutorial shows how to train a multi-layer GraphSAGE for node\n",
        "classification on ``ogbn-arxiv`` provided by [Open Graph\n",
        "Benchmark (OGB)](https://ogb.stanford.edu/). The dataset contains around\n",
        "170 thousand nodes and 1 million edges.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/node_classification.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/node_classification.ipynb)\n",
        "\n",
        "By the end of this tutorial, you will be able to\n",
        "\n",
        "-  Train a GNN model for node classification on a single GPU with DGL's\n",
        "   neighbor sampling components."
21
      ]
22
23
24
25
26
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mzZKrVVk6Y_8"
27
28
29
30
      },
      "source": [
        "## Install DGL package"
      ]
31
32
33
    },
    {
      "cell_type": "code",
34
35
36
37
38
      "execution_count": null,
      "metadata": {
        "id": "QcpjTazg6hEo"
      },
      "outputs": [],
39
40
41
42
43
44
45
46
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "import numpy as np\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "os.environ['DGLBACKEND'] = \"pytorch\"\n",
        "\n",
47
48
49
50
51
        "# Install the CPU version in default. If you want to install CUDA version,\n",
        "# please refer to https://www.dgl.ai/pages/start.html and change runtime type\n",
        "# accordingly.\n",
        "device = torch.device(\"cpu\")\n",
        "!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
52
53
54
55
56
57
58
59
60
        "\n",
        "try:\n",
        "    import dgl\n",
        "    import dgl.graphbolt as gb\n",
        "    installed = True\n",
        "except ImportError as error:\n",
        "    installed = False\n",
        "    print(error)\n",
        "print(\"DGL installed!\" if installed else \"DGL not found!\")"
61
      ]
62
63
64
    },
    {
      "cell_type": "markdown",
65
66
67
      "metadata": {
        "id": "XWdRZAM-51Cb"
      },
68
69
70
      "source": [
        "## Loading Dataset\n",
        "`ogbn-arxiv` is already prepared as ``BuiltinDataset`` in **GraphBolt**."
71
      ]
72
73
74
    },
    {
      "cell_type": "code",
75
      "execution_count": null,
76
77
78
      "metadata": {
        "id": "RnJkkSKhWiUG"
      },
79
80
81
82
      "outputs": [],
      "source": [
        "dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()"
      ]
83
84
85
86
87
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S8avoKBiXA9j"
88
89
90
91
      },
      "source": [
        "Dataset consists of graph, feature and tasks. You can get the training-validation-test set from the tasks. Seed nodes and corresponding labels are already stored in each training-validation-test set. Other metadata such as number of classes are also stored in the tasks. In this dataset, there is only one task: `node classification`."
      ]
92
93
94
    },
    {
      "cell_type": "code",
95
96
97
98
99
      "execution_count": null,
      "metadata": {
        "id": "IXGZmgIaXJWQ"
      },
      "outputs": [],
100
      "source": [
101
102
        "graph = dataset.graph.to(device)\n",
        "feature = dataset.feature.to(device)\n",
103
104
105
106
107
108
        "train_set = dataset.tasks[0].train_set\n",
        "valid_set = dataset.tasks[0].validation_set\n",
        "test_set = dataset.tasks[0].test_set\n",
        "task_name = dataset.tasks[0].metadata[\"name\"]\n",
        "num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n",
        "print(f\"Task: {task_name}. Number of classes: {num_classes}\")"
109
      ]
110
111
112
    },
    {
      "cell_type": "markdown",
113
114
115
      "metadata": {
        "id": "y8yn77Kg6HkW"
      },
116
117
118
119
120
      "source": [
        "## How DGL Handles Computation Dependency¶\n",
        "The computation dependency for message passing of a single node can be described as a series of message flow graphs (MFG).\n",
        "\n",
        "![DGL Computation](https://data.dgl.ai/tutorial/img/bipartite.gif)"
121
      ]
122
123
124
    },
    {
      "cell_type": "markdown",
125
126
127
      "metadata": {
        "id": "q7GrcJTnZQjt"
      },
128
129
130
      "source": [
        "## Defining Neighbor Sampler and Data Loader in DGL\n",
        "\n",
131
        "DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies to compute their outputs with the MFGs above. For node classification, you can use `dgl.graphbolt.DataLoader` for iterating over the dataset. It accepts a data pipe that generates minibatches of nodes and their labels, sample neighbors for each node, and generate the computation dependencies in the form of MFGs. Feature fetching, block creation and copying to target device are also supported. All these operations are split into separate stages in the data pipe, so that you can customize the data pipeline by inserting your own operations.\n",
132
133
        "\n",
        "Let’s say that each node will gather messages from 4 neighbors on each layer. The code defining the data loader and neighbor sampler will look like the following.\n"
134
      ]
135
136
137
    },
    {
      "cell_type": "code",
138
      "execution_count": null,
139
140
141
      "metadata": {
        "id": "yQVYDO0ZbBvi"
      },
142
143
144
145
146
147
148
149
150
      "outputs": [],
      "source": [
        "def create_dataloader(itemset, shuffle):\n",
        "    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n",
        "    datapipe = datapipe.copy_to(device, extra_attrs=[\"seed_nodes\"])\n",
        "    datapipe = datapipe.sample_neighbor(graph, [4, 4])\n",
        "    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
        "    return gb.DataLoader(datapipe)"
      ]
151
152
153
    },
    {
      "cell_type": "markdown",
154
155
156
      "metadata": {
        "id": "7Rp12SUhbEV1"
      },
157
      "source": [
158
        "You can iterate over the data loader and a `MiniBatch` object is yielded.\n",
159
        "\n"
160
      ]
161
162
163
    },
    {
      "cell_type": "code",
164
      "execution_count": null,
165
166
167
      "metadata": {
        "id": "V7vQiKj2bL_o"
      },
168
169
170
171
172
      "outputs": [],
      "source": [
        "data = next(iter(create_dataloader(train_set, shuffle=True)))\n",
        "print(data)"
      ]
173
174
175
176
177
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-eBuPnT-bS-o"
178
179
180
181
      },
      "source": [
        "You can get the input node IDs from MFGs."
      ]
182
183
184
    },
    {
      "cell_type": "code",
185
186
187
188
189
      "execution_count": null,
      "metadata": {
        "id": "bN4sgZqFbUvd"
      },
      "outputs": [],
190
191
192
193
      "source": [
        "mfgs = data.blocks\n",
        "input_nodes = mfgs[0].srcdata[dgl.NID]\n",
        "print(f\"Input nodes: {input_nodes}.\")"
194
      ]
195
196
197
    },
    {
      "cell_type": "markdown",
198
199
200
      "metadata": {
        "id": "fV6epnRxbZl4"
      },
201
202
203
204
      "source": [
        "## Defining Model\n",
        "Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:\n",
        "\n"
205
      ]
206
207
208
    },
    {
      "cell_type": "code",
209
210
211
212
213
      "execution_count": null,
      "metadata": {
        "id": "iKhEIL0Ccmwx"
      },
      "outputs": [],
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
      "source": [
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from dgl.nn import SAGEConv\n",
        "\n",
        "\n",
        "class Model(nn.Module):\n",
        "    def __init__(self, in_feats, h_feats, num_classes):\n",
        "        super(Model, self).__init__()\n",
        "        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type=\"mean\")\n",
        "        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type=\"mean\")\n",
        "        self.h_feats = h_feats\n",
        "\n",
        "    def forward(self, mfgs, x):\n",
        "        h = self.conv1(mfgs[0], x)\n",
        "        h = F.relu(h)\n",
        "        h = self.conv2(mfgs[1], h)\n",
        "        return h\n",
        "\n",
        "\n",
        "in_size = feature.size(\"node\", None, \"feat\")[0]\n",
        "model = Model(in_size, 64, num_classes).to(device)"
236
      ]
237
238
239
    },
    {
      "cell_type": "markdown",
240
241
242
      "metadata": {
        "id": "OGLN3kCcwCA8"
      },
243
244
245
246
      "source": [
        "## Defining Training Loop\n",
        "\n",
        "The following initializes the model and defines the optimizer.\n"
247
      ]
248
249
250
    },
    {
      "cell_type": "code",
251
      "execution_count": null,
252
253
254
      "metadata": {
        "id": "dET8i_hewLUi"
      },
255
256
257
258
      "outputs": [],
      "source": [
        "opt = torch.optim.Adam(model.parameters())"
      ]
259
260
261
262
263
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "leZvFP4GwMcq"
264
265
266
267
      },
      "source": [
        "When computing the validation score for model selection, usually you can also do neighbor sampling. We can just reuse our create_dataloader function to create two separate dataloaders for training and validation."
      ]
268
269
270
    },
    {
      "cell_type": "code",
271
      "execution_count": null,
272
273
274
      "metadata": {
        "id": "Gvd7vFWZwQI5"
      },
275
276
277
278
279
280
281
      "outputs": [],
      "source": [
        "train_dataloader = create_dataloader(train_set, shuffle=True)\n",
        "valid_dataloader = create_dataloader(valid_set, shuffle=False)\n",
        "\n",
        "import sklearn.metrics"
      ]
282
283
284
285
286
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nTIIfVMDwXqX"
287
288
289
290
      },
      "source": [
        "The following is a training loop that performs validation every epoch. It also saves the model with the best validation accuracy into a file."
      ]
291
292
293
    },
    {
      "cell_type": "code",
294
295
296
297
298
      "execution_count": null,
      "metadata": {
        "id": "wsfqhKUvwZEj"
      },
      "outputs": [],
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
      "source": [
        "import tqdm\n",
        "\n",
        "for epoch in range(10):\n",
        "    model.train()\n",
        "\n",
        "    with tqdm.tqdm(train_dataloader) as tq:\n",
        "        for step, data in enumerate(tq):\n",
        "            x = data.node_features[\"feat\"]\n",
        "            labels = data.labels\n",
        "\n",
        "            predictions = model(data.blocks, x)\n",
        "\n",
        "            loss = F.cross_entropy(predictions, labels)\n",
        "            opt.zero_grad()\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "\n",
        "            accuracy = sklearn.metrics.accuracy_score(\n",
        "                labels.cpu().numpy(),\n",
        "                predictions.argmax(1).detach().cpu().numpy(),\n",
        "            )\n",
        "\n",
        "            tq.set_postfix(\n",
        "                {\"loss\": \"%.03f\" % loss.item(), \"acc\": \"%.03f\" % accuracy},\n",
        "                refresh=False,\n",
        "            )\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    predictions = []\n",
        "    labels = []\n",
        "    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():\n",
        "        for data in tq:\n",
        "            x = data.node_features[\"feat\"]\n",
        "            labels.append(data.labels.cpu().numpy())\n",
        "            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\n",
        "        predictions = np.concatenate(predictions)\n",
        "        labels = np.concatenate(labels)\n",
        "        accuracy = sklearn.metrics.accuracy_score(labels, predictions)\n",
339
340
        "        print(\"Epoch {} Validation Accuracy {}\".format(epoch, accuracy))"
      ]
341
342
343
    },
    {
      "cell_type": "markdown",
344
345
346
      "metadata": {
        "id": "kmHnUI0QwfJ4"
      },
347
348
349
350
      "source": [
        "## Conclusion\n",
        "\n",
        "In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling.\n"
351
      ]
352
    }
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
378
}