{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1qfiZMOJYYv"
      },
      "source": [
        "# Graphbolt Quick Walkthrough\n",
        "\n",
        "This tutorial provides a quick walkthrough of the operators in the `dgl.graphbolt` package and illustrates how to create a GNN datapipe with it. To learn more about stochastic training of GNNs, please read the [materials](https://docs.dgl.ai/tutorials/large/index.html) provided by DGL.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fWiaC1WaDE-W"
      },
      "outputs": [],
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "os.environ['DGLBACKEND'] = \"pytorch\"\n",
        "\n",
        "# Install the CPU version.\n",
        "device = torch.device(\"cpu\")\n",
        "!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
        "\n",
        "try:\n",
        "    import dgl.graphbolt as gb\n",
        "    installed = True\n",
        "except ImportError as error:\n",
        "    installed = False\n",
        "    print(error)\n",
        "print(\"DGL installed!\" if installed else \"DGL not found!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8O7PfsY4sPoN"
      },
      "source": [
        "## Dataset\n",
        "\n",
        "The dataset has three primary components: (1) an **ItemSet**, which can be iterated over as the training target; (2) a **SamplingGraph**, which the subgraph sampling algorithm uses to generate subgraphs; and (3) a **FeatureStore**, which stores node, edge, and graph features.\n",
        "\n",
        "* The **ItemSet** is created from iterable data or a tuple of iterable data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g73ZAbMQsSgV"
      },
      "outputs": [],
      "source": [
        "seeds = torch.tensor(\n",
        "    [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n",
        "     [9, 6], [0, 6], [8, 6], [7, 7], [7, 7], [4, 7], [6, 8], [5, 8], [9, 9],\n",
        "     [4, 9], [4, 9], [5, 9], [9, 9], [5, 9], [9, 9], [7, 9]]\n",
        ")\n",
        "item_set = gb.ItemSet(seeds, names=\"seeds\")\n",
        "print(list(item_set))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lqty9p4cs0OR"
      },
      "source": [
        "* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. In graphbolt, we provide a canonical solution, the **FusedCSCSamplingGraph**, which achieves state-of-the-art time and space efficiency for CPU sampling. However, it requires enough CPU memory to host all FusedCSCSamplingGraph objects."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jDjY149xs3PI"
      },
      "outputs": [],
      "source": [
        "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
        "indices = torch.tensor(\n",
        "    [7, 6, 1, 3, 2, 8, 1, 2, 1, 9, 0, 8, 7, 7, 4, 6, 5, 9, 4, 4, 5, 9, 5, 9, 7]\n",
        ")\n",
        "num_edges = 25\n",
        "eid = torch.arange(num_edges)\n",
        "edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
        "graph = gb.fused_csc_sampling_graph(indptr, indices, edge_attributes=edge_attributes)\n",
        "print(graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mNp2S2_Vs8af"
      },
      "source": [
        "* The **FeatureStore** is used to store node, edge, and graph features. In graphbolt, we provide the **TorchBasedFeature** and related optimizations, such as the **GPUCachedFeature**, for different use cases."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zIU6KWe1Sm2g"
      },
      "outputs": [],
      "source": [
        "num_nodes = 10\n",
        "num_edges = 25\n",
        "node_feature_data = torch.rand((num_nodes, 2))\n",
        "edge_feature_data = torch.rand((num_edges, 3))\n",
        "node_feature = gb.TorchBasedFeature(node_feature_data)\n",
        "edge_feature = gb.TorchBasedFeature(edge_feature_data)\n",
        "features = {\n",
        "    (\"node\", None, \"feat\") : node_feature,\n",
        "    (\"edge\", None, \"feat\") : edge_feature,\n",
        "}\n",
        "feature_store = gb.BasicFeatureStore(features)\n",
        "print(feature_store)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oh2ockWWoXQ0"
      },
      "source": [
        "## DataPipe\n",
        "\n",
        "The DataPipe in Graphbolt is an extension of the PyTorch DataPipe, specifically designed to address the challenges of training graph neural networks (GNNs). Each stage of the data pipeline loads data from a different source and can be combined with other stages to create more complex pipelines. Intermediate data is stored in **MiniBatch** data packs.\n",
        "\n",
        "* **ItemSampler** iterates over the input **ItemSet** and creates mini-batches of items."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XtqPDprrogR7"
      },
      "outputs": [],
      "source": [
        "datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BjkAK37xopp1"
      },
      "source": [
        "* **NegativeSampler** generates negative samples and returns a mix of positive and negative samples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PrFpGoOGopJy"
      },
      "outputs": [],
      "source": [
        "datapipe = datapipe.sample_uniform_negative(graph, 1)\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fYO_oIwkpmb3"
      },
      "source": [
        "* **SubgraphSampler** samples a subgraph around a given set of seed nodes from the larger graph."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4UsY3PL3ppYV"
      },
      "outputs": [],
      "source": [
        "fanouts = torch.tensor([1])\n",
        "datapipe = datapipe.sample_neighbor(graph, [fanouts])\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0uIydsjUqMA0"
      },
      "source": [
        "* **FeatureFetcher** fetches node/edge features in graphbolt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YAj8G7YBqO6G"
      },
      "outputs": [],
      "source": [
        "datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"])\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjBSLPRPrsD2"
      },
      "source": [
        "* **CopyTo** copies the mini-batch to the target device, e.g., the GPU, for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RofiZOUMqt_u"
      },
      "outputs": [],
      "source": [
        "datapipe = datapipe.copy_to(device=device)\n",
        "print(next(iter(datapipe)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xm9HnyHRvxXj"
      },
      "source": [
        "## Exercise: Node classification\n",
        "\n",
        "Similarly, the following dataset is created for node classification. Can you implement the data pipeline for it?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YV-mk-xAv78v"
      },
      "outputs": [],
      "source": [
        "# Dataset for node classification.\n",
        "num_nodes = 10\n",
        "nodes = torch.arange(num_nodes)\n",
        "labels = torch.tensor([1, 2, 0, 2, 2, 0, 2, 2, 2, 2])\n",
        "item_set = gb.ItemSet((nodes, labels), names=(\"seeds\", \"labels\"))\n",
        "\n",
        "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
        "indices = torch.tensor(\n",
        "    [7, 6, 1, 3, 2, 8, 1, 2, 1, 9, 0, 8, 7, 7, 4, 6, 5, 9, 4, 4, 5, 9, 5, 9, 7]\n",
        ")\n",
        "eid = torch.arange(25)\n",
        "edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
        "graph = gb.fused_csc_sampling_graph(indptr, indices, edge_attributes=edge_attributes)\n",
        "\n",
        "num_nodes = 10\n",
        "num_edges = 25\n",
        "node_feature_data = torch.rand((num_nodes, 2))\n",
        "edge_feature_data = torch.rand((num_edges, 3))\n",
        "node_feature = gb.TorchBasedFeature(node_feature_data)\n",
        "edge_feature = gb.TorchBasedFeature(edge_feature_data)\n",
        "features = {\n",
        "    (\"node\", None, \"feat\") : node_feature,\n",
        "    (\"edge\", None, \"feat\") : edge_feature,\n",
        "}\n",
        "feature_store = gb.BasicFeatureStore(features)\n",
        "\n",
        "# Datapipe.\n",
        "...\n",
        "print(next(iter(datapipe)))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "BjkAK37xopp1"
      ],
      "gpuType": "T4",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}