{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "gpuClass": "standard" }, "cells": [ { "cell_type": "markdown", "source": [ "# Graph Diffusion in Graph Neural Networks\n", "\n", "This tutorial first briefly introduces the diffusion process on graphs. It then illustrates how Graph Neural Networks can utilize this concept to enhance their predictive power.\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/graph_diffusion.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/sparse/graph_diffusion.ipynb)" ], "metadata": { "id": "SfdsDpOK7yOT" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F6eQWmWn7lqh" }, "outputs": [], "source": [ "# Install required packages.\n", "import os\n", "import torch\n", "os.environ['TORCH'] = torch.__version__\n", "os.environ['DGLBACKEND'] = \"pytorch\"\n", "\n", "# Uncomment below to install required packages. If your CUDA version is not\n", "# 11.8, check https://www.dgl.ai/pages/start.html to find the supported CUDA\n", "# version and the corresponding command to install DGL.\n", "#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null\n", "#!pip install --upgrade scipy networkx > /dev/null\n", "\n", "try:\n", " import dgl\n", " installed = True\n", "except ImportError:\n", " installed = False\n", "print(\"DGL installed!\" if installed else \"Failed to install DGL!\")" ] }, { "cell_type": "markdown", "source": [ "## Graph Diffusion\n", "\n", "Diffusion describes the process of substances moving from one region to another. 
In the context of graphs, the diffusing substances (e.g., real-valued signals) travel along edges from node to node.\n", "\n", "Mathematically, let $\\vec x$ be the vector of node signals; then a graph diffusion operation can be defined as:\n", "\n", "$$\n", "\\vec{y} = \\tilde{A} \\vec{x}\n", "$$\n", "\n", "where $\\tilde{A}$ is the **diffusion matrix**, typically derived from the adjacency matrix of the graph. Although the choice of diffusion matrix may vary, it is typically sparse, so $\\tilde{A} \\vec{x}$ is a sparse-dense matrix multiplication.\n", "\n", "Let us illustrate this with a simple example. First, we obtain the adjacency matrix of the famous [Karate Club Network](https://en.wikipedia.org/wiki/Zachary%27s_karate_club)." ], "metadata": { "id": "iH6os3oFcyze" } }, { "cell_type": "code", "source": [ "import dgl\n", "import dgl.sparse as dglsp\n", "from dgl.data import KarateClubDataset\n", "\n", "# Get the graph from DGL's builtin dataset.\n", "dataset = KarateClubDataset()\n", "dgl_g = dataset[0]\n", "\n", "# Get its adjacency matrix.\n", "indices = torch.stack(dgl_g.edges())\n", "N = dgl_g.num_nodes()\n", "A = dglsp.spmatrix(indices, shape=(N, N))\n", "print(A.to_dense())" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_TnCECJmBKJE", "outputId": "d8b78f0b-3a1c-4a9e-bcc9-ed4df7b7b5b7" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[0., 1., 1., ..., 1., 0., 0.],\n", " [1., 0., 1., ..., 0., 0., 0.],\n", " [1., 1., 0., ..., 0., 1., 0.],\n", " ...,\n", " [1., 0., 0., ..., 0., 1., 1.],\n", " [0., 0., 1., ..., 1., 0., 1.],\n", " [0., 0., 0., ..., 1., 1., 0.]])\n" ] } ] }, { "cell_type": "markdown", "source": [ "We use the graph convolution matrix from Graph Convolutional Networks as the diffusion matrix in this example. 
The graph convolution matrix is defined as:\n", "\n", "$$\\tilde{A} = \\hat{D}^{-\\frac{1}{2}}\\hat{A}\\hat{D}^{-\\frac{1}{2}}$$\n", "\n", "with $\\hat{A} = A + I$, where $A$ denotes the adjacency matrix, $I$ denotes the identity matrix, and $\\hat{D}$ is the diagonal node degree matrix of $\\hat{A}$." ], "metadata": { "id": "wJMT4oHOCCqJ" } }, { "cell_type": "code", "source": [ "# Compute graph convolution matrix.\n", "I = dglsp.identity(A.shape)\n", "A_hat = A + I\n", "D_hat = dglsp.diag(A_hat.sum(dim=1))\n", "D_hat_invsqrt = D_hat ** -0.5\n", "A_tilde = D_hat_invsqrt @ A_hat @ D_hat_invsqrt\n", "print(A_tilde.to_dense())" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JyzctBGaC_O5", "outputId": "b03ef3dc-dcf5-494e-9191-30591d09f138" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[0.0588, 0.0767, 0.0731, ..., 0.0917, 0.0000, 0.0000],\n", " [0.0767, 0.1000, 0.0953, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0731, 0.0953, 0.0909, ..., 0.0000, 0.0836, 0.0000],\n", " ...,\n", " [0.0917, 0.0000, 0.0000, ..., 0.1429, 0.1048, 0.0891],\n", " [0.0000, 0.0000, 0.0836, ..., 0.1048, 0.0769, 0.0654],\n", " [0.0000, 0.0000, 0.0000, ..., 0.0891, 0.0654, 0.0556]])\n" ] } ] }, { "cell_type": "markdown", "source": [ "For the node signals, we set all nodes but one to zero." ], "metadata": { "id": "geYvWuUkDbiL" } }, { "cell_type": "code", "source": [ "# Initial node signals. All nodes except one are set to zero.\n", "X = torch.zeros(N)\n", "X[0] = 5.\n", "\n", "# Number of diffusion steps.\n", "r = 8\n", "\n", "# Record the signals after each diffusion step.\n", "results = [X]\n", "for _ in range(r):\n", " X = A_tilde @ X\n", " results.append(X)" ], "metadata": { "id": "DXb0uKqXDZKb" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "The program below visualizes the diffusion process with an animation. To play the animation, click the \"play\" icon. 
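The convergence can also be checked numerically, independent of the animation. Below is a minimal sketch (using a hypothetical 3-node path graph with dense tensors for readability, not the Karate Club graph or DGL sparse matrices) showing that successive applications of $\tilde{A}$ change the signal less and less:

```python
import torch

# Toy 3-node path graph 0-1-2; dense tensors for illustration only.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_hat = A + torch.eye(3)
# Symmetric normalization: D_hat^{-1/2} A_hat D_hat^{-1/2}.
D_hat_invsqrt = torch.diag(A_hat.sum(dim=1) ** -0.5)
A_tilde = D_hat_invsqrt @ A_hat @ D_hat_invsqrt

X = torch.tensor([5., 0., 0.])  # signal concentrated on node 0
prev = X
for step in range(1, 9):
    cur = A_tilde @ prev
    # The change between consecutive steps shrinks as the signal converges.
    print(f"step {step}: change = {torch.norm(cur - prev):.4f}")
    prev = cur
```

The printed change decays geometrically, governed by the second-largest eigenvalue of $\tilde{A}$; the same behavior drives the animation above toward a steady state.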
You will see how node features converge over time." ], "metadata": { "id": "TpqMz4muF2aO" } }, { "cell_type": "code", "source": [ "import matplotlib.pyplot as plt\n", "import networkx as nx\n", "from IPython.display import HTML\n", "from matplotlib import animation\n", "\n", "nx_g = dgl_g.to_networkx().to_undirected()\n", "pos = nx.spring_layout(nx_g)\n", "\n", "fig, ax = plt.subplots()\n", "plt.close()\n", "\n", "def animate(i):\n", " ax.cla()\n", " # Color nodes based on their features.\n", " nodes = nx.draw_networkx_nodes(nx_g, pos, ax=ax, node_size=200, node_color=results[i].tolist(), cmap=plt.cm.Blues)\n", " # Set boundary color of the nodes.\n", " nodes.set_edgecolor(\"#000000\")\n", " nx.draw_networkx_edges(nx_g, pos, ax=ax)\n", "\n", "ani = animation.FuncAnimation(fig, animate, frames=len(results), interval=1000)\n", "HTML(ani.to_jshtml())" ], "metadata": { "id": "eN3kmJ8nl7_z", "colab": { "base_uri": "https://localhost:8080/", "height": 386 }, "outputId": "be93263e-2283-4db7-caff-2e15e75ceb02" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ], "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "markdown", "source": [ "## Graph Diffusion in GNNs\n", "\n", "[Scalable Inception Graph Neural Networks (SIGN)](https://arxiv.org/abs/2004.11198) leverages multiple diffusion operators simultaneously. Formally, it is defined as:\n", "\n", "$$\n", "Z=\\sigma([X\\Theta_{0},A_1X\\Theta_{1},\\cdots,A_rX\\Theta_{r}])\\\\\n", "Y=\\xi(Z\\Omega)\n", "$$\n", "\n", "where:\n", "* $\\sigma$ and $\\xi$ are nonlinear activation functions.\n", "* $[\\cdot,\\cdots,\\cdot]$ is the concatenation operation.\n", "* $X\\in\\mathbb{R}^{n\\times d}$ is the input node feature matrix with $n$ nodes and a $d$-dimensional feature vector per node.\n", "* $\\Theta_0,\\cdots,\\Theta_r\\in\\mathbb{R}^{d\\times d'}$ are learnable weight matrices.\n", "* $A_1,\\cdots, A_r\\in\\mathbb{R}^{n\\times n}$ are linear diffusion operators. In the example below, we set $A_i = A^i$, where $A$ is the convolution matrix of the graph.\n", "* $\\Omega\\in\\mathbb{R}^{d'(r+1)\\times c}$ is a learnable weight matrix and $c$ is the number of classes.\n", "\n", "The code below implements the diffusion function to compute $A_1X, A_2X, \\cdots, A_rX$ and the module that combines all the diffused node features."
], "metadata": { "id": "unL_mAj-TqC6" } }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "################################################################################\n", "# (HIGHLIGHT) Take advantage of DGL sparse APIs to implement the feature\n", "# diffusion in SIGN concisely.\n", "################################################################################\n", "def sign_diffusion(A, X, r):\n", " # Perform the r-hop diffusion operation.\n", " X_sign = [X]\n", " for i in range(r):\n", " # A^i X\n", " X = A @ X\n", " X_sign.append(X)\n", " return X_sign\n", "\n", "class SIGN(nn.Module):\n", " def __init__(self, in_size, out_size, r, hidden_size=256):\n", " super().__init__()\n", " self.theta = nn.ModuleList(\n", " [nn.Linear(in_size, hidden_size) for _ in range(r + 1)]\n", " )\n", " self.omega = nn.Linear(hidden_size * (r + 1), out_size)\n", "\n", " def forward(self, X_sign):\n", " results = []\n", " for i in range(len(X_sign)):\n", " results.append(self.theta[i](X_sign[i]))\n", " Z = F.relu(torch.cat(results, dim=1))\n", " return self.omega(Z)" ], "metadata": { "id": "__U3Hsp_S0SR" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Training\n", "\n", "We train the SIGN model on the [Cora dataset](https://docs.dgl.ai/en/latest/generated/dgl.data.CoraGraphDataset.html). The node features are diffused in the pre-processing stage."
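Before training on a real dataset, it may help to trace the tensor shapes SIGN produces. The sketch below mirrors the module above on random data, with made-up toy sizes and a dense random matrix standing in for the normalized adjacency (illustration only, not DGL's sparse API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n, d, c, r = 6, 4, 3, 2  # toy sizes: nodes, feature dim, classes, hops

# Random symmetric stand-in for the diffusion matrix.
A = torch.rand(n, n)
A = (A + A.T) / 2

X = torch.rand(n, d)

# Precompute [X, AX, A^2 X] once; training never touches A again.
X_sign = [X]
for _ in range(r):
    X_sign.append(A @ X_sign[-1])

hidden = 8
theta = nn.ModuleList([nn.Linear(d, hidden) for _ in range(r + 1)])
omega = nn.Linear(hidden * (r + 1), c)

# One hidden block per diffusion operator, concatenated feature-wise.
Z = F.relu(torch.cat([th(x) for th, x in zip(theta, X_sign)], dim=1))
Y = omega(Z)
print(Z.shape, Y.shape)  # -> torch.Size([6, 24]) torch.Size([6, 3])
```

Because the diffused features are fixed, each training epoch is just `r + 1` dense linear layers plus one output layer; this is what makes SIGN scalable.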
], "metadata": { "id": "ngyh4-YZTkNY" } }, { "cell_type": "code", "source": [ "from dgl.data import CoraGraphDataset\n", "from torch.optim import Adam\n", "\n", "\n", "def evaluate(g, pred):\n", " label = g.ndata[\"label\"]\n", " val_mask = g.ndata[\"val_mask\"]\n", " test_mask = g.ndata[\"test_mask\"]\n", "\n", " # Compute accuracy on validation/test set.\n", " val_acc = (pred[val_mask] == label[val_mask]).float().mean()\n", " test_acc = (pred[test_mask] == label[test_mask]).float().mean()\n", " return val_acc, test_acc\n", "\n", "\n", "def train(model, g, X_sign):\n", " label = g.ndata[\"label\"]\n", " train_mask = g.ndata[\"train_mask\"]\n", " optimizer = Adam(model.parameters(), lr=3e-3)\n", "\n", " for epoch in range(10):\n", " # Switch the model to training mode.\n", " model.train()\n", "\n", " # Forward.\n", " logits = model(X_sign)\n", "\n", " # Compute loss with nodes in training set.\n", " loss = F.cross_entropy(logits[train_mask], label[train_mask])\n", "\n", " # Backward.\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # Switch the model to evaluation mode.\n", " model.eval()\n", "\n", " # Compute prediction.\n", " logits = model(X_sign)\n", " pred = logits.argmax(1)\n", "\n", " # Evaluate the prediction.\n", " val_acc, test_acc = evaluate(g, pred)\n", " print(\n", " f\"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test\"\n", " f\" acc: {test_acc:.3f}\"\n", " )\n", "\n", "\n", "# If CUDA is available, use the GPU to accelerate training; otherwise, use\n", "# the CPU.\n", "dev = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Load graph from the existing dataset.\n", "dataset = CoraGraphDataset()\n", "g = dataset[0].to(dev)\n", "\n", "# Create the sparse adjacency matrix A (note that W was used as the notation\n", "# for adjacency matrix in the original paper).\n", "indices = torch.stack(g.edges())\n", "N = g.num_nodes()\n", "A = dglsp.spmatrix(indices, shape=(N, N))\n", "\n", 
"# Calculate the graph convolution matrix.\n", "I = dglsp.identity(A.shape, device=dev)\n", "A_hat = A + I\n", "D_hat_invsqrt = dglsp.diag(A_hat.sum(dim=1)) ** -0.5\n", "A_hat = D_hat_invsqrt @ A_hat @ D_hat_invsqrt\n", "\n", "# 2-hop diffusion.\n", "r = 2\n", "X = g.ndata[\"feat\"]\n", "X_sign = sign_diffusion(A_hat, X, r)\n", "\n", "# Create SIGN model.\n", "in_size = X.shape[1]\n", "out_size = dataset.num_classes\n", "model = SIGN(in_size, out_size, r).to(dev)\n", "\n", "# Kick off training.\n", "train(model, g, X_sign)" ], "metadata": { "id": "58WnPtPvT2mx", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "19e86f6a-c7f1-4b40-8cfc-58a181fc30d7" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...\n", "Extracting file to /root/.dgl/cora_v2\n", "Finished data loading and preprocessing.\n", " NumNodes: 2708\n", " NumEdges: 10556\n", " NumFeats: 1433\n", " NumClasses: 7\n", " NumTrainingSamples: 140\n", " NumValidationSamples: 500\n", " NumTestSamples: 1000\n", "Done saving data into cached files.\n", "In epoch 0, loss: 1.946, val acc: 0.164, test acc: 0.200\n", "In epoch 1, loss: 1.937, val acc: 0.712, test acc: 0.690\n", "In epoch 2, loss: 1.926, val acc: 0.610, test acc: 0.595\n", "In epoch 3, loss: 1.914, val acc: 0.656, test acc: 0.640\n", "In epoch 4, loss: 1.898, val acc: 0.724, test acc: 0.726\n", "In epoch 5, loss: 1.880, val acc: 0.734, test acc: 0.753\n", "In epoch 6, loss: 1.859, val acc: 0.730, test acc: 0.746\n", "In epoch 7, loss: 1.834, val acc: 0.732, test acc: 0.743\n", "In epoch 8, loss: 1.807, val acc: 0.734, test acc: 0.746\n", "In epoch 9, loss: 1.776, val acc: 0.734, test acc: 0.745\n" ] } ] }, { "cell_type": "markdown", "source": [ "Check out the full example script [here](https://github.com/dmlc/dgl/blob/master/examples/sparse/sign.py). 
Learn more about how graph diffusion is used in other GNN models:\n", "\n", "* *Predict then Propagate: Graph Neural Networks meet Personalized PageRank* [paper](https://arxiv.org/abs/1810.05997) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/appnp.py)\n", "* *Combining Label Propagation and Simple Models Out-performs Graph Neural Networks* [paper](https://arxiv.org/abs/2010.13993) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/c_and_s.py)\n", "* *Simplifying Graph Convolutional Networks* [paper](https://arxiv.org/abs/1902.07153) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/sgc.py)\n", "* *Graph Neural Networks Inspired by Classical Iterative Algorithms* [paper](https://arxiv.org/pdf/2103.06064.pdf) [code](https://github.com/dmlc/dgl/blob/master/examples/sparse/twirls.py)" ], "metadata": { "id": "lI2Nms8PXq-y" } } ] }