"This tutorial shows how to train a multi-layer GraphSAGE for node classification on the `ogbn-products` dataset provided by [Open Graph\n",
"Benchmark (OGB)](https://ogb.stanford.edu/). The dataset contains around 2.4 million nodes and 62 million edges.\n",
"\n",
"[](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/stochastic_training/multigpu_node_classification.ipynb)\n",
"By the end of this tutorial, you will be able to\n",
"\n",
"- Train a GNN model for node classification on multiple GPUs with DGL's neighbor sampling components. After learning how to use multiple GPUs, you will\n",
"be able to extend it to other scenarios such as link prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mzZKrVVk6Y_8"
},
"source": [
"## Install DGL package and other dependencies"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "QTCc1RrD_5Id"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
"Looking in links: https://data.dgl.ai/wheels-test/cu121/repo.html\n",
"Requirement already satisfied: dgl in /localscratch/dgl-3/python (2.1)\n",
"Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (1.24.4)\n",
"Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (1.11.4)\n",
"Requirement already satisfied: networkx>=2.1 in /usr/local/lib/python3.10/dist-packages (from dgl) (2.6.3)\n",
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (2.31.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from dgl) (4.66.1)\n",
"Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (5.9.4)\n",
"Requirement already satisfied: torchdata>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from dgl) (0.7.0a0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (1.26.18)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->dgl) (2023.11.17)\n",
"Requirement already satisfied: torch>=2 in /usr/local/lib/python3.10/dist-packages (from torchdata>=0.5.0->dgl) (2.2.0a0+81ea7a4)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (4.8.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (1.12)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.1.2)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata>=0.5.0->dgl) (2023.12.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2->torchdata>=0.5.0->dgl) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2->torchdata>=0.5.0->dgl) (1.3.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
"Requirement already satisfied: torchmetrics in /usr/local/lib/python3.10/dist-packages (1.3.0.post0)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (0.70.16)\n",
"Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (1.24.4)\n",
"Requirement already satisfied: packaging>17.1 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (23.2)\n",
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (2.2.0a0+81ea7a4)\n",
"Requirement already satisfied: lightning-utilities>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (0.10.1)\n",
"Requirement already satisfied: dill>=0.3.8 in /usr/local/lib/python3.10/dist-packages (from multiprocess) (0.3.8)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (68.2.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (4.8.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.13.1)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2.6.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.1.2)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2023.12.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->torchmetrics) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->torchmetrics) (1.3.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"DGL installed!\n"
]
}
],
"source": [
"# Install required packages.\n",
"import os\n",
"import torch\n",
"os.environ['TORCH'] = torch.__version__\n",
"os.environ['DGLBACKEND'] = \"pytorch\"\n",
"\n",
"# Install the CUDA version. If you want to install CPU version, please\n",
"# refer to https://www.dgl.ai/pages/start.html.\n",
"print(\"DGL installed!\" if installed else \"DGL not found!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q7GrcJTnZQjt"
},
"source": [
"## Defining Neighbor Sampler and Data Loader in DGL\n",
"\n",
"The major difference from the previous tutorial is that we will use `DistributedItemSampler` instead of `ItemSampler` to sample mini-batches of nodes. `DistributedItemSampler` is a distributed version of `ItemSampler` that works with `DistributedDataParallel`. It is implemented as a wrapper around `ItemSampler` and will sample the same minibatch on all replicas. It also supports dropping the last non-full minibatch to avoid the need for padding.\n"
"As the different GPUs might process differing numbers of data points, we define a function to compute the exact average of values such as loss or accuracy in a\n",
"The evaluation function can be used to calculate the validation accuracy during training or the testing accuracy at the end of the training. The difference from\n",
"the previous tutorial is that we need to return the number of items processed\n",
" for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:\n",
" blocks = data.blocks\n",
" x = data.node_features[\"feat\"]\n",
" y.append(data.labels)\n",
" y_hats.append(model.module(blocks, x))\n",
"\n",
" res = MF.accuracy(\n",
" torch.cat(y_hats),\n",
" torch.cat(y),\n",
" task=\"multiclass\",\n",
" num_classes=num_classes,\n",
" )\n",
"\n",
" return res.to(device), sum(y_i.size(0) for y_i in y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kN5BbnR4HSU2"
},
"source": [
"## Training Loop\n",
"\n",
"The training loop is almost identical to the previous tutorial. In this tutorial, we explicitly disable uneven inputs coming from the dataloader, however, the Join Context Manager could be used to train possibly with incomplete batches at the end of epochs. Please refer to [this tutorial](https://pytorch.org/tutorials/advanced/generic_join.html) for more information."
" # We synchronize before measuring the epoch time.\n",
" torch.cuda.synchronize()\n",
" epoch_end = time.time()\n",
" if rank == 0:\n",
" print(\n",
" f\"Epoch {epoch:05d} | \"\n",
" f\"Average Loss {total_loss.item():.4f} | \"\n",
" f\"Accuracy {acc.item():.4f} | \"\n",
" f\"Time {epoch_end - epoch_start:.4f}\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mA-Xu37uIHc4"
},
"source": [
"## Defining Training and Evaluation Procedures\n",
"\n",
"The following code defines the main function for each process. It is similar to the previous tutorial except that we need to initialize a distributed training context with `torch.distributed` and wrap the model with `torch.nn.parallel.DistributedDataParallel`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "sW__HeslIMTT"
},
"outputs": [],
"source": [
"def run(rank, world_size, devices, dataset):\n",
" # Set up multiprocessing environment.\n",
" device = devices[rank]\n",
" torch.cuda.set_device(device)\n",
" dist.init_process_group(\n",
" backend=\"nccl\", # Use NCCL backend for distributed GPU training\n",
" init_method=\"tcp://127.0.0.1:12345\",\n",
" world_size=world_size,\n",
" rank=rank,\n",
" )\n",
"\n",
" # Pin the graph and features in-place to enable GPU access.\n",