{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jv-tHPvR-JKa"
      },
      "source": [
        "# Graph Transformer in a Nutshell\n",
        "\n",
        "The **Transformer** [(Vaswani et al. 2017)](https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) has proven to be an effective learning architecture in natural language processing and computer vision.\n",
        "Recently, researchers have begun to explore applying the Transformer to graph learning, with initial success on many practical tasks such as graph property prediction.\n",
        "[Dwivedi et al. (2020)](https://arxiv.org/abs/2012.09699) first generalized the Transformer architecture to graph-structured data. Here, we show how to build such a graph transformer with DGL's sparse matrix APIs.\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install required packages.\n",
        "import os\n",
        "import torch\n",
        "os.environ['TORCH'] = torch.__version__\n",
        "os.environ['DGLBACKEND'] = \"pytorch\"\n",
        "\n",
        "# Uncomment below to install required packages. If your CUDA version is not 11.8,\n",
        "# check https://www.dgl.ai/pages/start.html to find the supported CUDA\n",
        "# versions and the corresponding command to install DGL.\n",
        "#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\n",
        "#!pip install ogb >/dev/null\n",
        "\n",
        "try:\n",
        "    import dgl\n",
        "    installed = True\n",
        "except ImportError:\n",
        "    installed = False\n",
        "print(\"DGL installed!\" if installed else \"Failed to install DGL!\")"
      ],
      "metadata": {
        "id": "8wIJZQqODy-7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nOpFdtLI-JKb"
      },
      "source": [
        "## Sparse Multi-head Attention\n",
        "\n",
        "Recall the all-pairs scaled dot-product attention mechanism in the vanilla Transformer:\n",
        "\n",
        "$$\\text{Attn}=\\text{softmax}(\\dfrac{QK^T} {\\sqrt{d}})V.$$\n",
        "\n",
        "The graph transformer (GT) model employs a Sparse Multi-head Attention block:\n",
        "\n",
        "$$\\text{SparseAttn}(Q, K, V, A) = \\text{softmax}(\\frac{(QK^T) \\circ A}{\\sqrt{d}})V,$$\n",
        "\n",
        "where $Q, K, V \\in \\mathbb{R}^{N\\times d}$ are the query, key, and value features, respectively, and $A\\in\\{0,1\\}^{N\\times N}$ is the adjacency matrix of the input graph. $(QK^T)\\circ A$ means that the product of the query and key matrices is followed by a Hadamard (element-wise) product with the sparse adjacency matrix, as illustrated in the figure below:\n",
        "\n",
        "<img src=\"https://drive.google.com/uc?id=1OgMAewLR3Z1vz5y4J8aPRSeaU3g8iQfX\" width=\"500\">\n",
        "\n",
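        "To make the masking concrete, consider a hypothetical 3-node path graph in which node 2 is connected to nodes 1 and 3 (the graph and the score symbols $s_{ij}$, standing for the entries of $QK^T$, are made up for illustration):\n",
        "\n",
        "$$A = \\begin{pmatrix} 0 & 1 & 0 \\\\ 1 & 0 & 1 \\\\ 0 & 1 & 0 \\end{pmatrix}, \\qquad (QK^T)\\circ A = \\begin{pmatrix} 0 & s_{12} & 0 \\\\ s_{21} & 0 & s_{23} \\\\ 0 & s_{32} & 0 \\end{pmatrix}.$$\n",
        "\n",
        "In this small example, only the four entries that correspond to edges would need to be computed and stored.\n",
        "\n",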
        "Essentially, only the attention scores between connected nodes are computed, according to the sparsity of $A$. This operation is also called *Sampled Dense-Dense Matrix Multiplication (SDDMM)*.\n",
        "\n",
        "Using DGL's [batched SDDMM API](https://docs.dgl.ai/en/latest/generated/dgl.sparse.bsddmm.html), we can parallelize the computation across multiple attention heads (different representation subspaces).\n",
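        "\n",
        "As a rough sketch of what the sparse attention formula above computes for a single node and a single head (the notation $\\mathcal{N}(i)$, $q_i$, $k_j$, $v_j$, and $\\alpha_{ij}$ is introduced here for illustration, and we assume the softmax normalizes only over the nonzero entries in each row of $A$):\n",
        "\n",
        "$$\\alpha_{ij} = \\frac{\\exp(q_i^\\top k_j / \\sqrt{d})}{\\sum_{l \\in \\mathcal{N}(i)} \\exp(q_i^\\top k_l / \\sqrt{d})}, \\qquad \\text{SparseAttn}(Q, K, V, A)_i = \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{ij} v_j,$$\n",
        "\n",
        "where $\\mathcal{N}(i) = \\{j : A_{ij} \\neq 0\\}$ and $q_i$, $k_j$, $v_j$ denote rows of $Q$, $K$, $V$. Under this reading, each node attends only to its neighbors, so the cost scales with the number of edges rather than with $N^2$.\n",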
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dh7zc5v0-JKb"
      },
      "outputs": [],
      "source": [
        "import dgl\n",
        "import dgl.nn as dglnn\n",
        "import dgl.sparse as dglsp\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "from dgl.data import AsGraphPredDataset\n",
        "from dgl.dataloading import GraphDataLoader\n",
        "from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator\n",
        "from ogb.graphproppred.mol_encoder import AtomEncoder\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "class SparseMHA(nn.Module):\n",
        "    \"\"\"Sparse Multi-head Attention Module\"\"\"\n",
        "\n",
        "    def __init__(self, hidden_size=80, num_heads=8):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "        self.num_heads = num_heads\n",
        "        self.head_dim = hidden_size // num_heads\n",
        "        self.scaling = self.head_dim**-0.5\n",
        "\n",
        "        self.q_proj = nn.Linear(hidden_size, hidden_size)\n",
        "        self.k_proj = nn.Linear(hidden_size, hidden_size)\n",
        "        self.v_proj = nn.Linear(hidden_size, hidden_size)\n",
        "        self.out_proj = nn.Linear(hidden_size, hidden_size)\n",
        "\n",
        "    def forward(self, A, h):\n",
        "        N = len(h)\n",
        "        # [N, dh, nh]\n",
        "        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
        "        q *= self.scaling\n",
        "        # [N, dh, nh]\n",
        "        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
        "        # [N, dh, nh]\n",
        "        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
        "\n",
        "        ######################################################################\n",
        "        # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API\n",
        "        ######################################################################\n",
        "        attn = dglsp.bsddmm(A, q, k.transpose(1, 0))  # (sparse) [N, N, nh]\n",
        "        # Sparse softmax by default applies on the last sparse dimension.\n",
        "        attn = attn.softmax()  # (sparse) [N, N, nh]\n",
        "        out = dglsp.bspmm(attn, v)  # [N, dh, nh]\n",
        "\n",
        "        return self.out_proj(out.reshape(N, -1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3_Fm6Lrx-JKc"
      },
      "source": [
        "## Graph Transformer Layer\n",
        "\n",
        "The GT layer is composed of multi-head attention, batch normalization, and a feed-forward network, connected by residual links as in the vanilla Transformer.\n",
        "\n",
        "<img src=\"https://drive.google.com/uc?id=1cm-Ijw7bUQIOkoTKn5MQ3m4-66JqCsMz\" width=\"300\">"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M6h7JVWT-JKd"
      },
      "outputs": [],
      "source": [
        "class GTLayer(nn.Module):\n",
        "    \"\"\"Graph Transformer Layer\"\"\"\n",
        "\n",
        "    def __init__(self, hidden_size=80, num_heads=8):\n",
        "        super().__init__()\n",
        "        self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)\n",
        "        self.batchnorm1 = nn.BatchNorm1d(hidden_size)\n",
        "        self.batchnorm2 = nn.BatchNorm1d(hidden_size)\n",
        "        self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)\n",
        "        self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)\n",
        "\n",
        "    def forward(self, A, h):\n",
        "        h1 = h\n",
        "        h = self.MHA(A, h)\n",
        "        h = self.batchnorm1(h + h1)\n",
        "\n",
        "        h2 = h\n",
        "        h = self.FFN2(F.relu(self.FFN1(h)))\n",
        "        h = h2 + h\n",
        "\n",
        "        return self.batchnorm2(h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t40DhVjI-JKd"
      },
      "source": [
        "## Graph Transformer Model\n",
        "\n",
        "The GT model is constructed by stacking GT layers. The input positional encoding of the vanilla Transformer is replaced with the Laplacian positional encoding [(Dwivedi et al. 2020)](https://arxiv.org/abs/2003.00982). For graph-level prediction tasks, an extra pooler is stacked on top of the GT layers to aggregate the node features of each graph."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UrjvEBrF-JKe"
      },
      "outputs": [],
      "source": [
        "class GTModel(nn.Module):\n",
        "    def __init__(\n",
        "        self,\n",
        "        out_size,\n",
        "        hidden_size=80,\n",
        "        pos_enc_size=2,\n",
        "        num_layers=8,\n",
        "        num_heads=8,\n",
        "    ):\n",
        "        super().__init__()\n",
        "        self.atom_encoder = AtomEncoder(hidden_size)\n",
        "        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)\n",
        "        self.layers = nn.ModuleList(\n",
        "            [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]\n",
        "        )\n",
        "        self.pooler = dglnn.SumPooling()\n",
        "        self.predictor = nn.Sequential(\n",
        "            nn.Linear(hidden_size, hidden_size // 2),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_size // 2, hidden_size // 4),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_size // 4, out_size),\n",
        "        )\n",
        "\n",
        "    def forward(self, g, X, pos_enc):\n",
        "        indices = torch.stack(g.edges())\n",
        "        N = g.num_nodes()\n",
        "        A = dglsp.spmatrix(indices, shape=(N, N))\n",
        "        h = self.atom_encoder(X) + self.pos_linear(pos_enc)\n",
        "        for layer in self.layers:\n",
        "            h = layer(A, h)\n",
        "        h = self.pooler(g, h)\n",
        "\n",
        "        return self.predictor(h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RdrPU18I-JKe"
      },
      "source": [
        "## Training\n",
        "\n",
        "We train the GT model on the [ogbg-molhiv](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol) benchmark. The Laplacian positional encoding of each graph is pre-computed (with the API [here](https://docs.dgl.ai/en/latest/generated/dgl.laplacian_pe.html)) as part of the input to the model.\n",
        "\n",
        "*Note that we down-sample the dataset to make this demo run faster. See the* [*example script*](https://github.com/dmlc/dgl/blob/master/examples/sparse/graph_transformer.py) *for the performance on the full dataset.*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V41i0w-9-JKe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "15343d1a-a32d-4677-d053-d9da96910f43"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Computing Laplacian PE:   1%|          | 25/4000 [00:00<00:16, 244.77it/s]/usr/local/lib/python3.8/dist-packages/dgl/backend/pytorch/tensor.py:52: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:250.)\n",
            "  return th.as_tensor(data, dtype=dtype)\n",
            "Computing Laplacian PE: 100%|██████████| 4000/4000 [00:13<00:00, 296.04it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 000, Loss: 0.2486, Val: 0.3082, Test: 0.3068\n",
            "Epoch: 001, Loss: 0.1695, Val: 0.4684, Test: 0.4572\n",
            "Epoch: 002, Loss: 0.1428, Val: 0.5887, Test: 0.4721\n",
            "Epoch: 003, Loss: 0.1237, Val: 0.6375, Test: 0.5010\n",
            "Epoch: 004, Loss: 0.1127, Val: 0.6628, Test: 0.4854\n",
            "Epoch: 005, Loss: 0.1047, Val: 0.6811, Test: 0.4983\n",
            "Epoch: 006, Loss: 0.0949, Val: 0.6751, Test: 0.5409\n",
            "Epoch: 007, Loss: 0.0901, Val: 0.6340, Test: 0.5357\n",
            "Epoch: 008, Loss: 0.0811, Val: 0.6717, Test: 0.5543\n",
            "Epoch: 009, Loss: 0.0643, Val: 0.7861, Test: 0.5628\n",
            "Epoch: 010, Loss: 0.0489, Val: 0.7319, Test: 0.5341\n",
            "Epoch: 011, Loss: 0.0340, Val: 0.7884, Test: 0.5299\n",
            "Epoch: 012, Loss: 0.0285, Val: 0.5887, Test: 0.4293\n",
            "Epoch: 013, Loss: 0.0361, Val: 0.5514, Test: 0.3419\n",
            "Epoch: 014, Loss: 0.0451, Val: 0.6795, Test: 0.4964\n",
            "Epoch: 015, Loss: 0.0429, Val: 0.7405, Test: 0.5527\n",
            "Epoch: 016, Loss: 0.0331, Val: 0.7859, Test: 0.4994\n",
            "Epoch: 017, Loss: 0.0177, Val: 0.6544, Test: 0.4457\n",
            "Epoch: 018, Loss: 0.0201, Val: 0.8250, Test: 0.6073\n",
            "Epoch: 019, Loss: 0.0093, Val: 0.7356, Test: 0.5561\n"
          ]
        }
      ],
      "source": [
        "@torch.no_grad()\n",
        "def evaluate(model, dataloader, evaluator, device):\n",
        "    model.eval()\n",
        "    y_true = []\n",
        "    y_pred = []\n",
        "    for batched_g, labels in dataloader:\n",
        "        batched_g, labels = batched_g.to(device), labels.to(device)\n",
        "        y_hat = model(batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"])\n",
        "        y_true.append(labels.view(y_hat.shape).detach().cpu())\n",
        "        y_pred.append(y_hat.detach().cpu())\n",
        "    y_true = torch.cat(y_true, dim=0).numpy()\n",
        "    y_pred = torch.cat(y_pred, dim=0).numpy()\n",
        "    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n",
        "    return evaluator.eval(input_dict)[\"rocauc\"]\n",
        "\n",
        "\n",
        "def train(model, dataset, evaluator, device):\n",
        "    train_dataloader = GraphDataLoader(\n",
        "        dataset[dataset.train_idx],\n",
        "        batch_size=256,\n",
        "        shuffle=True,\n",
        "        collate_fn=collate_dgl,\n",
        "    )\n",
        "    valid_dataloader = GraphDataLoader(\n",
        "        dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl\n",
        "    )\n",
        "    test_dataloader = GraphDataLoader(\n",
        "        dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl\n",
        "    )\n",
        "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "    num_epochs = 20\n",
        "    scheduler = optim.lr_scheduler.StepLR(\n",
        "        optimizer, step_size=num_epochs, gamma=0.5\n",
        "    )\n",
        "    loss_fcn = nn.BCEWithLogitsLoss()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        total_loss = 0.0\n",
        "        for batched_g, labels in train_dataloader:\n",
        "            batched_g, labels = batched_g.to(device), labels.to(device)\n",
        "            logits = model(\n",
        "                batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"]\n",
        "            )\n",
        "            loss = loss_fcn(logits, labels.float())\n",
        "            total_loss += loss.item()\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "        scheduler.step()\n",
        "        avg_loss = total_loss / len(train_dataloader)\n",
        "        val_metric = evaluate(model, valid_dataloader, evaluator, device)\n",
        "        test_metric = evaluate(model, test_dataloader, evaluator, device)\n",
        "        print(\n",
        "            f\"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, \"\n",
        "            f\"Val: {val_metric:.4f}, Test: {test_metric:.4f}\"\n",
        "        )\n",
        "\n",
        "\n",
        "# Training device.\n",
        "dev = torch.device(\"cpu\")\n",
        "# Uncomment the code below to train on GPU. Be sure to install DGL with CUDA support.\n",
        "#dev = torch.device(\"cuda:0\")\n",
        "\n",
        "# Load dataset.\n",
        "pos_enc_size = 8\n",
        "dataset = AsGraphPredDataset(\n",
        "    DglGraphPropPredDataset(\"ogbg-molhiv\", \"./data/OGB\")\n",
        ")\n",
        "evaluator = Evaluator(\"ogbg-molhiv\")\n",
        "\n",
        "# Down sample the dataset to make the tutorial run faster.\n",
        "import random\n",
        "random.seed(42)\n",
        "train_size = len(dataset.train_idx)\n",
        "val_size = len(dataset.val_idx)\n",
        "test_size = len(dataset.test_idx)\n",
        "dataset.train_idx = dataset.train_idx[\n",
        "    torch.LongTensor(random.sample(range(train_size), 2000))\n",
        "]\n",
        "dataset.val_idx = dataset.val_idx[\n",
        "    torch.LongTensor(random.sample(range(val_size), 1000))\n",
        "]\n",
        "dataset.test_idx = dataset.test_idx[\n",
        "    torch.LongTensor(random.sample(range(test_size), 1000))\n",
        "]\n",
        "\n",
        "# Laplacian positional encoding.\n",
        "indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])\n",
        "for idx in tqdm(indices, desc=\"Computing Laplacian PE\"):\n",
        "    g, _ = dataset[idx]\n",
        "    g.ndata[\"PE\"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)\n",
        "\n",
        "# Create model.\n",
        "out_size = dataset.num_tasks\n",
        "model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)\n",
        "\n",
        "# Kick off training.\n",
        "train(model, dataset, evaluator, dev)"
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    },
    "orig_nbformat": 4,
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}