Unverified Commit 2c143daf authored by Minjie Wang, committed by GitHub

[Doc] Colab tutorial for graph transformer (#5188)

* Created using Colaboratory

* add gt nblink

* Created using Colaboratory

* update

* Add nblink
parent 2a806231
{
"path": "../../../../notebooks/sparse/graph_transformer.ipynb"
}
@@ -9,3 +9,4 @@ TODO(minjie): intro for the new library.
quickstart.nblink
graph_diffusion.nblink
graph_transformer.nblink
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Jv-tHPvR-JKa"
},
"source": [
"# Graph Transformer in a Nutshell\n",
"\n",
"The **Transformer** [(Vaswani et al. 2017)](https://proceedings.neurips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) has been proven an effective learning architecture in natural language processing and computer vision.\n",
"Recently, researchers turns to explore the application of transformer in graph learning. They have achieved inital success on many practical tasks, e.g., graph property prediction.\n",
"[Dwivedi et al. (2020)](https://arxiv.org/abs/2012.09699) firstly generalize the transformer neural architecture to graph-structured data. Here, we present how to build such a graph transformer with DGL's sparse matrix APIs.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/graph_transformer.ipynb)"
]
},
{
"cell_type": "code",
"source": [
"# Install required packages.\n",
"import os\n",
"import torch\n",
"os.environ['TORCH'] = torch.__version__\n",
"os.environ['DGLBACKEND'] = \"pytorch\"\n",
"\n",
"# TODO(Steve): change to stable version.\n",
"# Uncomment below to install required packages.\n",
"#!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html >/dev/null\n",
"#!pip install ogb >/dev/null\n",
"\n",
"try:\n",
" import dgl\n",
" installed = True\n",
"except ImportError:\n",
" installed = False\n",
"print(\"DGL installed!\" if installed else \"Failed to install DGL!\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8wIJZQqODy-7",
"outputId": "d9a9400a-e600-4749-b786-2736af35a130"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"DGL installed!\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nOpFdtLI-JKb"
},
"source": [
"## Sparse Multi-head Attention\n",
"\n",
"Unlike the all-pairs scaled-dot-product attention mechanism in vanillar Transformer:\n",
"\n",
"$$\\text{Attn}=\\text{softmax}(\\dfrac{QK^T} {\\sqrt{d}})V,$$\n",
"\n",
"the graph transformer (GT) model employs a Sparse Multi-head Attention block for message passing:\n",
"\n",
"$$\\text{SparseAttn}(Q, K, V, A) = \\text{softmax}(\\frac{(QK^T) \\circ A}{\\sqrt{d}})V,$$\n",
"\n",
"where $Q, K, V ∈\\mathbb{R}^{N\\times d}$ are query feature, key feature, and value feature, respectively. $A\\in[0,1]^{N\\times N}$ is the adjacency matrix of the input graph. $(QK^T)\\circ A$ denotes Sampled Dense Dense Matrix Multiplication (SDDMM) as shown in the figure below:\n",
"\n",
"<img src=\"https://drive.google.com/uc?id=1OgMAewLR3Z1vz5y4J8aPRSeaU3g8iQfX\" width=\"500\">\n",
"\n",
"Only attention scores between connected nodes are computed according to the sparsity of $A$. Enjoying the [batched SDDMM API](https://docs.dgl.ai/en/latest/generated/dgl.sparse.bsddmm.html) in DGL, we can parallel the computation on multiple attention heads (different representation subspaces).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dh7zc5v0-JKb"
},
"outputs": [],
"source": [
"import dgl\n",
"import dgl.nn as dglnn\n",
"import dgl.sparse as dglsp\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"from dgl.data import AsGraphPredDataset\n",
"from dgl.dataloading import GraphDataLoader\n",
"from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator\n",
"from ogb.graphproppred.mol_encoder import AtomEncoder\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class SparseMHA(nn.Module):\n",
" \"\"\"Sparse Multi-head Attention Module\"\"\"\n",
"\n",
" def __init__(self, hidden_size=80, num_heads=8):\n",
" super().__init__()\n",
" self.hidden_size = hidden_size\n",
" self.num_heads = num_heads\n",
" self.head_dim = hidden_size // num_heads\n",
" self.scaling = self.head_dim**-0.5\n",
"\n",
" self.q_proj = nn.Linear(hidden_size, hidden_size)\n",
" self.k_proj = nn.Linear(hidden_size, hidden_size)\n",
" self.v_proj = nn.Linear(hidden_size, hidden_size)\n",
" self.out_proj = nn.Linear(hidden_size, hidden_size)\n",
"\n",
" def forward(self, A, h):\n",
" N = len(h)\n",
" # [N, dh, nh]\n",
" q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
" q *= self.scaling\n",
" # [N, dh, nh]\n",
" k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
" # [N, dh, nh]\n",
" v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)\n",
"\n",
" ######################################################################\n",
" # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API\n",
" ######################################################################\n",
" attn = dglsp.bsddmm(A, q, k.transpose(1, 0)) # [N, N, nh]\n",
" attn = attn.softmax() # [N, N, nh]\n",
" out = dglsp.bspmm(attn, v) # [N, dh, nh]\n",
"\n",
" return self.out_proj(out.reshape(N, -1))"
]
},
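{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the cell below runs `SparseMHA` on a toy 4-node graph with random features. It is a minimal sketch: the edge list, node count, and feature sizes are arbitrary choices for illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (illustration only): run SparseMHA on a toy graph.\n",
"# The 4-node edge list and random features are arbitrary demonstration choices.\n",
"toy_src = torch.tensor([0, 1, 2, 3, 0])\n",
"toy_dst = torch.tensor([1, 2, 3, 0, 2])\n",
"A_toy = dglsp.from_coo(toy_dst, toy_src, shape=(4, 4))  # sparse adjacency matrix\n",
"h_toy = torch.rand(4, 80)  # random node features with hidden_size=80\n",
"\n",
"mha = SparseMHA(hidden_size=80, num_heads=8)\n",
"out_toy = mha(A_toy, h_toy)\n",
"print(out_toy.shape)  # expected: torch.Size([4, 80])"
]
},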
{
"cell_type": "markdown",
"metadata": {
"id": "3_Fm6Lrx-JKc"
},
"source": [
"## Graph Transformer Layer\n",
"\n",
"The GT layer is composed of Multi-head Attention, Batch Norm, and Feed-forward Network, connected by residual links as in vanilla transformer.\n",
"\n",
"<img src=\"https://drive.google.com/uc?id=1cm-Ijw7bUQIOkoTKn5MQ3m4-66JqCsMz\" width=\"300\">"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M6h7JVWT-JKd"
},
"outputs": [],
"source": [
"class GTLayer(nn.Module):\n",
" \"\"\"Graph Transformer Layer\"\"\"\n",
"\n",
" def __init__(self, hidden_size=80, num_heads=8):\n",
" super().__init__()\n",
" self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)\n",
" self.batchnorm1 = nn.BatchNorm1d(hidden_size)\n",
" self.batchnorm2 = nn.BatchNorm1d(hidden_size)\n",
" self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)\n",
" self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)\n",
"\n",
" def forward(self, A, h):\n",
" h1 = h\n",
" h = self.MHA(A, h)\n",
" h = self.batchnorm1(h + h1)\n",
"\n",
" h2 = h\n",
" h = self.FFN2(F.relu(self.FFN1(h)))\n",
" h = h2 + h\n",
"\n",
" return self.batchnorm2(h)"
]
},
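{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal shape check: it pushes random features through a single `GTLayer` on a toy graph. The graph and feature sizes are arbitrary illustration choices, and the layer is switched to eval mode so that batch normalization uses its running statistics on this tiny input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (illustration only): shape check for a single GTLayer.\n",
"cycle_src = torch.tensor([0, 1, 2, 3])\n",
"cycle_dst = torch.tensor([1, 2, 3, 0])\n",
"A_cycle = dglsp.from_coo(cycle_dst, cycle_src, shape=(4, 4))\n",
"\n",
"gt_layer = GTLayer(hidden_size=80, num_heads=8)\n",
"gt_layer.eval()  # use BatchNorm running statistics on this tiny example\n",
"h_out = gt_layer(A_cycle, torch.rand(4, 80))\n",
"print(h_out.shape)  # expected: torch.Size([4, 80]); residual links keep the width"
]
},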
{
"cell_type": "markdown",
"metadata": {
"id": "t40DhVjI-JKd"
},
"source": [
"## Graph Transformer Model\n",
"\n",
"The GT model is constructed by stacking GT layers.\n",
"\n",
"The input positional encoding of vanilla transformer is replaced with Laplacian positional encoding [(Dwivedi et al. 2020)](https://arxiv.org/abs/2003.00982).\n",
"\n",
"For the graph-level prediction task, an extra pooler is stacked on top of GT layers to aggregate node feature of the same graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UrjvEBrF-JKe"
},
"outputs": [],
"source": [
"class GTModel(nn.Module):\n",
" def __init__(\n",
" self,\n",
" out_size,\n",
" hidden_size=80,\n",
" pos_enc_size=2,\n",
" num_layers=8,\n",
" num_heads=8,\n",
" ):\n",
" super().__init__()\n",
" self.atom_encoder = AtomEncoder(hidden_size)\n",
" self.pos_linear = nn.Linear(pos_enc_size, hidden_size)\n",
" self.layers = nn.ModuleList(\n",
" [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]\n",
" )\n",
" self.pooler = dglnn.SumPooling()\n",
" self.predictor = nn.Sequential(\n",
" nn.Linear(hidden_size, hidden_size // 2),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size // 2, hidden_size // 4),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size // 4, out_size),\n",
" )\n",
"\n",
" def forward(self, g, X, pos_enc):\n",
" src, dst = g.edges()\n",
" N = g.num_nodes()\n",
" A = dglsp.from_coo(dst, src, shape=(N, N))\n",
" h = self.atom_encoder(X) + self.pos_linear(pos_enc)\n",
" for layer in self.layers:\n",
" h = layer(A, h)\n",
" h = self.pooler(g, h)\n",
"\n",
" return self.predictor(h)"
]
},
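{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a small sketch for orientation only: it instantiates the model with `out_size=1` (ogbg-molhiv is a single binary classification task) and reports the number of stacked GT layers and parameters. No graph data is needed for this check."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (illustration only): inspect the model size without loading data.\n",
"toy_model = GTModel(out_size=1)  # ogbg-molhiv has a single binary prediction task\n",
"num_params = sum(p.numel() for p in toy_model.parameters())\n",
"print(f\"GT layers: {len(toy_model.layers)}, parameters: {num_params}\")"
]
},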
{
"cell_type": "markdown",
"metadata": {
"id": "RdrPU18I-JKe"
},
"source": [
"## Training\n",
"\n",
"We train the GT model on [ogbg-molhiv](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol) benchmark. The Laplacian positional encoding of each graph is pre-computed (with the API [here](https://docs.dgl.ai/en/latest/generated/dgl.laplacian_pe.html)) as part of the input to the model."
]
},
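{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before running the full training script, the cell below sketches the pre-computation step on a toy undirected cycle graph: `dgl.laplacian_pe` returns a k-dimensional positional encoding per node and, with `padding=True`, zero-pads the remaining columns when the graph has fewer eigenvectors than requested. The toy graph is an arbitrary illustration choice; `k=8` matches the `pos_enc_size` used below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (illustration only): Laplacian positional encoding on a toy graph.\n",
"# A 6-node cycle with edges in both directions, so the graph is undirected.\n",
"u = torch.tensor([0, 1, 2, 3, 4, 5])\n",
"v = torch.tensor([1, 2, 3, 4, 5, 0])\n",
"toy_g = dgl.graph((torch.cat([u, v]), torch.cat([v, u])))\n",
"toy_pe = dgl.laplacian_pe(toy_g, k=8, padding=True)\n",
"print(toy_pe.shape)  # expected: torch.Size([6, 8]); extra columns are zero-padded"
]
},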
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V41i0w-9-JKe",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "15343d1a-a32d-4677-d053-d9da96910f43"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Computing Laplacian PE: 1%| | 25/4000 [00:00<00:16, 244.77it/s]/usr/local/lib/python3.8/dist-packages/dgl/backend/pytorch/tensor.py:52: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:250.)\n",
" return th.as_tensor(data, dtype=dtype)\n",
"Computing Laplacian PE: 100%|██████████| 4000/4000 [00:13<00:00, 296.04it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch: 000, Loss: 0.2486, Val: 0.3082, Test: 0.3068\n",
"Epoch: 001, Loss: 0.1695, Val: 0.4684, Test: 0.4572\n",
"Epoch: 002, Loss: 0.1428, Val: 0.5887, Test: 0.4721\n",
"Epoch: 003, Loss: 0.1237, Val: 0.6375, Test: 0.5010\n",
"Epoch: 004, Loss: 0.1127, Val: 0.6628, Test: 0.4854\n",
"Epoch: 005, Loss: 0.1047, Val: 0.6811, Test: 0.4983\n",
"Epoch: 006, Loss: 0.0949, Val: 0.6751, Test: 0.5409\n",
"Epoch: 007, Loss: 0.0901, Val: 0.6340, Test: 0.5357\n",
"Epoch: 008, Loss: 0.0811, Val: 0.6717, Test: 0.5543\n",
"Epoch: 009, Loss: 0.0643, Val: 0.7861, Test: 0.5628\n",
"Epoch: 010, Loss: 0.0489, Val: 0.7319, Test: 0.5341\n",
"Epoch: 011, Loss: 0.0340, Val: 0.7884, Test: 0.5299\n",
"Epoch: 012, Loss: 0.0285, Val: 0.5887, Test: 0.4293\n",
"Epoch: 013, Loss: 0.0361, Val: 0.5514, Test: 0.3419\n",
"Epoch: 014, Loss: 0.0451, Val: 0.6795, Test: 0.4964\n",
"Epoch: 015, Loss: 0.0429, Val: 0.7405, Test: 0.5527\n",
"Epoch: 016, Loss: 0.0331, Val: 0.7859, Test: 0.4994\n",
"Epoch: 017, Loss: 0.0177, Val: 0.6544, Test: 0.4457\n",
"Epoch: 018, Loss: 0.0201, Val: 0.8250, Test: 0.6073\n",
"Epoch: 019, Loss: 0.0093, Val: 0.7356, Test: 0.5561\n"
]
}
],
"source": [
"@torch.no_grad()\n",
"def evaluate(model, dataloader, evaluator, device):\n",
" model.eval()\n",
" y_true = []\n",
" y_pred = []\n",
" for batched_g, labels in dataloader:\n",
" batched_g, labels = batched_g.to(device), labels.to(device)\n",
" y_hat = model(batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"])\n",
" y_true.append(labels.view(y_hat.shape).detach().cpu())\n",
" y_pred.append(y_hat.detach().cpu())\n",
" y_true = torch.cat(y_true, dim=0).numpy()\n",
" y_pred = torch.cat(y_pred, dim=0).numpy()\n",
" input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n",
" return evaluator.eval(input_dict)[\"rocauc\"]\n",
"\n",
"\n",
"def train(model, dataset, evaluator, device):\n",
" train_dataloader = GraphDataLoader(\n",
" dataset[dataset.train_idx],\n",
" batch_size=256,\n",
" shuffle=True,\n",
" collate_fn=collate_dgl,\n",
" )\n",
" valid_dataloader = GraphDataLoader(\n",
" dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl\n",
" )\n",
" test_dataloader = GraphDataLoader(\n",
" dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl\n",
" )\n",
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
" num_epochs = 20\n",
" scheduler = optim.lr_scheduler.StepLR(\n",
" optimizer, step_size=num_epochs, gamma=0.5\n",
" )\n",
" loss_fcn = nn.BCEWithLogitsLoss()\n",
"\n",
" for epoch in range(num_epochs):\n",
" model.train()\n",
" total_loss = 0.0\n",
" for batched_g, labels in train_dataloader:\n",
" batched_g, labels = batched_g.to(device), labels.to(device)\n",
" logits = model(\n",
" batched_g, batched_g.ndata[\"feat\"], batched_g.ndata[\"PE\"]\n",
" )\n",
" loss = loss_fcn(logits, labels.float())\n",
" total_loss += loss.item()\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" scheduler.step()\n",
" avg_loss = total_loss / len(train_dataloader)\n",
" val_metric = evaluate(model, valid_dataloader, evaluator, device)\n",
" test_metric = evaluate(model, test_dataloader, evaluator, device)\n",
" print(\n",
" f\"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, \"\n",
" f\"Val: {val_metric:.4f}, Test: {test_metric:.4f}\"\n",
" )\n",
"\n",
"\n",
"# Training device.\n",
"dev = torch.device(\"cpu\")\n",
"# Uncomment the code below to train on GPU. Be sure to install DGL with CUDA support.\n",
"#dev = torch.device(\"cuda:0\")\n",
"\n",
"# Load dataset.\n",
"pos_enc_size = 8\n",
"dataset = AsGraphPredDataset(\n",
" DglGraphPropPredDataset(\"ogbg-molhiv\", \"./data/OGB\")\n",
")\n",
"evaluator = Evaluator(\"ogbg-molhiv\")\n",
"\n",
"# Down sample the dataset to make the tutorial run faster.\n",
"import random\n",
"random.seed(42)\n",
"train_size = len(dataset.train_idx)\n",
"val_size = len(dataset.val_idx)\n",
"test_size = len(dataset.test_idx)\n",
"dataset.train_idx = dataset.train_idx[\n",
" torch.LongTensor(random.sample(range(train_size), 2000))\n",
"]\n",
"dataset.val_idx = dataset.val_idx[\n",
" torch.LongTensor(random.sample(range(val_size), 1000))\n",
"]\n",
"dataset.test_idx = dataset.test_idx[\n",
" torch.LongTensor(random.sample(range(test_size), 1000))\n",
"]\n",
"\n",
"# Laplacian positional encoding.\n",
"indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])\n",
"for idx in tqdm(indices, desc=\"Computing Laplacian PE\"):\n",
" g, _ = dataset[idx]\n",
" g.ndata[\"PE\"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)\n",
"\n",
"# Create model.\n",
"out_size = dataset.num_tasks\n",
"model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)\n",
"\n",
"# Kick off training.\n",
"train(model, dataset, evaluator, dev)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4,
"colab": {
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}