"tests/vscode:/vscode.git/clone" did not exist on "4f499c7ffb5f6746c4ccd87b93b0ef09c32cf424"
node_classification.ipynb 12.3 KB
Newer Older
1
2
3
4
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OxbY2KlG4ZfJ"
      },
      "source": [
        "# Node Classification\n",
        "This tutorial shows how to train a multi-layer GraphSAGE for node\n",
        "classification on ``ogbn-arxiv`` provided by [Open Graph\n",
        "Benchmark (OGB)](https://ogb.stanford.edu/). The dataset contains around\n",
        "170 thousand nodes and 1 million edges.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/node_classification.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/stochastic_training/node_classification.ipynb)\n",
        "\n",
        "By the end of this tutorial, you will be able to\n",
        "\n",
        "-  Train a GNN model for node classification on a single GPU with DGL's\n",
        "   neighbor sampling components."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mzZKrVVk6Y_8"
      },
      "source": [
        "## Install DGL package"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QcpjTazg6hEo"
      },
      "outputs": [],
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "import numpy as np\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "os.environ['DGLBACKEND'] = \"pytorch\"\n",
        "\n",
        "# Install the CUDA version. If you want to install the CPU version, please\n",
        "# refer to https://www.dgl.ai/pages/start.html.\n",
        "device = torch.device(\"cuda\")\n",
        "!pip install --pre dgl -f https://data.dgl.ai/wheels-test/cu121/repo.html\n",
        "\n",
        "try:\n",
        "    import dgl\n",
        "    import dgl.graphbolt as gb\n",
        "    installed = True\n",
        "except ImportError as error:\n",
        "    installed = False\n",
        "    print(error)\n",
        "print(\"DGL installed!\" if installed else \"DGL not found!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XWdRZAM-51Cb"
      },
      "source": [
        "## Loading Dataset\n",
        "`ogbn-arxiv` is already prepared as ``BuiltinDataset`` in **GraphBolt**."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RnJkkSKhWiUG"
      },
      "outputs": [],
      "source": [
        "dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S8avoKBiXA9j"
      },
      "source": [
        "Dataset consists of graph, feature and tasks. You can get the training-validation-test set from the tasks. Seed nodes and corresponding labels are already stored in each training-validation-test set. Other metadata such as number of classes are also stored in the tasks. In this dataset, there is only one task: `node classification`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IXGZmgIaXJWQ"
      },
      "outputs": [],
      "source": [
        "graph = dataset.graph.to(device)\n",
        "feature = dataset.feature.to(device)\n",
        "train_set = dataset.tasks[0].train_set\n",
        "valid_set = dataset.tasks[0].validation_set\n",
        "test_set = dataset.tasks[0].test_set\n",
        "task_name = dataset.tasks[0].metadata[\"name\"]\n",
        "num_classes = dataset.tasks[0].metadata[\"num_classes\"]\n",
        "print(f\"Task: {task_name}. Number of classes: {num_classes}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y8yn77Kg6HkW"
      },
      "source": [
        "## How DGL Handles Computation Dependency¶\n",
        "The computation dependency for message passing of a single node can be described as a series of message flow graphs (MFG).\n",
        "\n",
        "![DGL Computation](https://data.dgl.ai/tutorial/img/bipartite.gif)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q7GrcJTnZQjt"
      },
      "source": [
        "## Defining Neighbor Sampler and Data Loader in DGL\n",
        "\n",
        "DGL provides tools to iterate over the dataset in minibatches while generating the computation dependencies to compute their outputs with the MFGs above. For node classification, you can use `dgl.graphbolt.DataLoader` for iterating over the dataset. It accepts a data pipe that generates minibatches of nodes and their labels, sample neighbors for each node, and generate the computation dependencies in the form of MFGs. Feature fetching, block creation and copying to target device are also supported. All these operations are split into separate stages in the data pipe, so that you can customize the data pipeline by inserting your own operations.\n",
        "\n",
        "Let’s say that each node will gather messages from 4 neighbors on each layer. The code defining the data loader and neighbor sampler will look like the following.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yQVYDO0ZbBvi"
      },
      "outputs": [],
      "source": [
        "def create_dataloader(itemset, shuffle):\n",
        "    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n",
        "    datapipe = datapipe.copy_to(device, extra_attrs=[\"seed_nodes\"])\n",
        "    datapipe = datapipe.sample_neighbor(graph, [4, 4])\n",
        "    datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
        "    return gb.DataLoader(datapipe)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Rp12SUhbEV1"
      },
      "source": [
        "You can iterate over the data loader and a `MiniBatch` object is yielded.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V7vQiKj2bL_o"
      },
      "outputs": [],
      "source": [
        "data = next(iter(create_dataloader(train_set, shuffle=True)))\n",
        "print(data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-eBuPnT-bS-o"
      },
      "source": [
        "You can get the input node IDs from MFGs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bN4sgZqFbUvd"
      },
      "outputs": [],
      "source": [
        "mfgs = data.blocks\n",
        "input_nodes = mfgs[0].srcdata[dgl.NID]\n",
        "print(f\"Input nodes: {input_nodes}.\")"
      ]
    },
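    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each MFG is a bipartite graph: messages flow from a larger set of source (input) nodes to a smaller set of destination (output) nodes. As a quick check, the minimal sketch below prints the two sizes for every block, using the standard DGL block methods `num_src_nodes()` and `num_dst_nodes()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Inspect the bipartite structure of each MFG (block).\n",
        "for layer, block in enumerate(mfgs):\n",
        "    print(\n",
        "        f\"Layer {layer}: {block.num_src_nodes()} source nodes -> \"\n",
        "        f\"{block.num_dst_nodes()} destination nodes\"\n",
        "    )"
      ]
    },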
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fV6epnRxbZl4"
      },
      "source": [
        "## Defining Model\n",
        "Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The model can be written as follows:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iKhEIL0Ccmwx"
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from dgl.nn import SAGEConv\n",
        "\n",
        "\n",
        "class Model(nn.Module):\n",
        "    def __init__(self, in_feats, h_feats, num_classes):\n",
        "        super(Model, self).__init__()\n",
        "        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type=\"mean\")\n",
        "        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type=\"mean\")\n",
        "        self.h_feats = h_feats\n",
        "\n",
        "    def forward(self, mfgs, x):\n",
        "        h = self.conv1(mfgs[0], x)\n",
        "        h = F.relu(h)\n",
        "        h = self.conv2(mfgs[1], h)\n",
        "        return h\n",
        "\n",
        "\n",
        "in_size = feature.size(\"node\", None, \"feat\")[0]\n",
        "model = Model(in_size, 64, num_classes).to(device)"
      ]
    },
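    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before writing the training loop, you can sanity-check the model on the minibatch sampled earlier. This is a minimal sketch that assumes the `data` object from above is still in scope."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# One forward pass on the sampled minibatch as a sanity check.\n",
        "# The output has one row per seed node and one column per class.\n",
        "with torch.no_grad():\n",
        "    out = model(data.blocks, data.node_features[\"feat\"])\n",
        "print(out.shape)"
      ]
    },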
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OGLN3kCcwCA8"
      },
      "source": [
        "## Defining Training Loop\n",
        "\n",
        "The following initializes the model and defines the optimizer.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dET8i_hewLUi"
      },
      "outputs": [],
      "source": [
        "opt = torch.optim.Adam(model.parameters())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "leZvFP4GwMcq"
      },
      "source": [
        "When computing the validation score for model selection, usually you can also do neighbor sampling. We can just reuse our create_dataloader function to create two separate dataloaders for training and validation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gvd7vFWZwQI5"
      },
      "outputs": [],
      "source": [
        "train_dataloader = create_dataloader(train_set, shuffle=True)\n",
        "valid_dataloader = create_dataloader(valid_set, shuffle=False)\n",
        "\n",
        "import sklearn.metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nTIIfVMDwXqX"
      },
      "source": [
        "The following is a training loop that performs validation every epoch. It also saves the model with the best validation accuracy into a file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wsfqhKUvwZEj"
      },
      "outputs": [],
      "source": [
        "import tqdm\n",
        "\n",
        "for epoch in range(10):\n",
        "    model.train()\n",
        "\n",
        "    with tqdm.tqdm(train_dataloader) as tq:\n",
        "        for step, data in enumerate(tq):\n",
        "            x = data.node_features[\"feat\"]\n",
        "            labels = data.labels\n",
        "\n",
        "            predictions = model(data.blocks, x)\n",
        "\n",
        "            loss = F.cross_entropy(predictions, labels)\n",
        "            opt.zero_grad()\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "\n",
        "            accuracy = sklearn.metrics.accuracy_score(\n",
        "                labels.cpu().numpy(),\n",
        "                predictions.argmax(1).detach().cpu().numpy(),\n",
        "            )\n",
        "\n",
        "            tq.set_postfix(\n",
        "                {\"loss\": \"%.03f\" % loss.item(), \"acc\": \"%.03f\" % accuracy},\n",
        "                refresh=False,\n",
        "            )\n",
        "\n",
        "    model.eval()\n",
        "\n",
        "    predictions = []\n",
        "    labels = []\n",
        "    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():\n",
        "        for data in tq:\n",
        "            x = data.node_features[\"feat\"]\n",
        "            labels.append(data.labels.cpu().numpy())\n",
        "            predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\n",
        "        predictions = np.concatenate(predictions)\n",
        "        labels = np.concatenate(labels)\n",
        "        accuracy = sklearn.metrics.accuracy_score(labels, predictions)\n",
338
339
        "        print(\"Epoch {} Validation Accuracy {}\".format(epoch, accuracy))"
      ]
    },
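    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an optional last step, you can load the best checkpoint saved by the training loop and evaluate it on the test set. The minimal sketch below reuses `create_dataloader` with `test_set`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Load the best checkpoint and evaluate it on the test set.\n",
        "model.load_state_dict(torch.load(best_model_path))\n",
        "model.eval()\n",
        "\n",
        "test_dataloader = create_dataloader(test_set, shuffle=False)\n",
        "predictions = []\n",
        "labels = []\n",
        "with torch.no_grad():\n",
        "    for data in test_dataloader:\n",
        "        x = data.node_features[\"feat\"]\n",
        "        labels.append(data.labels.cpu().numpy())\n",
        "        predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\n",
        "accuracy = sklearn.metrics.accuracy_score(\n",
        "    np.concatenate(labels), np.concatenate(predictions)\n",
        ")\n",
        "print(f\"Test Accuracy: {accuracy}\")"
      ]
    },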
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kmHnUI0QwfJ4"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}