Unverified Commit e1f663f8 authored by Ramon Zhou, committed by GitHub

[Misc] Add comments and highlights to examples/multigpu/node_classification_sage.py (#5956)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 43d49c1c
"""
This script trains and tests a GraphSAGE model for node classification on
multiple GPUs with distributed data-parallel training (DDP).
Before reading this example, please familiarize yourself with GraphSAGE node
classification using neighbor sampling by reading the example in
`examples/sampling/node_classification.py`.
This flowchart describes the main functional sequence of the provided example.
main
├───> Load and preprocess dataset
└───> run (multiprocessing)
      ├───> Init process group and build distributed SAGE model (HIGHLIGHT)
      ├───> train
      │     ├───> NeighborSampler
      │     └───> Training loop
      │           ├───> SAGE.forward
      │           └───> Collect validation accuracy (HIGHLIGHT)
      └───> layerwise_infer
            └───> SAGE.inference
                  ├───> MultiLayerFullNeighborSampler
                  └───> Use a shared output tensor
"""
import argparse
import os
import time
@@ -27,7 +62,7 @@ class SAGE(nn.Module):
def __init__(self, in_size, hid_size, out_size):
super().__init__()
self.layers = nn.ModuleList()
# three-layer GraphSAGE-mean
# Three-layer GraphSAGE-mean
self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
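For context, the constructor above is paired (outside this hunk) with a forward pass that walks the sampled blocks, one block per SAGEConv layer. A minimal self-contained stand-in of that pattern is sketched below; the class name and the 0.5 dropout rate are assumptions, not the example's exact code.

import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F


class MiniSAGE(nn.Module):
    """Illustrative stand-in for the SAGE class in this file."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                dglnn.SAGEConv(in_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, out_size, "mean"),
            ]
        )
        self.dropout = nn.Dropout(0.5)  # assumed rate; not shown in this hunk

    def forward(self, blocks, x):
        # One sampled block per layer: blocks[0] is the outermost hop.
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:  # no activation after the last layer
                h = F.relu(h)
                h = self.dropout(h)
        return h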
@@ -57,11 +92,11 @@ class SAGE(nn.Module):
shuffle=False,
drop_last=False,
num_workers=0,
use_ddp=True,
use_ddp=True, # use DDP
use_uva=use_uva,
)
# in order to prevent running out of GPU memory, allocate a
# shared output tensor 'y' in host memory
# In order to prevent running out of GPU memory, allocate a shared
# output tensor 'y' in host memory.
y = shared_tensor(
(
g.num_nodes(),
@@ -78,9 +113,9 @@ class SAGE(nn.Module):
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
# non_blocking (with pinned memory) to accelerate data transfer
# Use non_blocking (with pinned memory) to accelerate data transfer
y[output_nodes] = h.to(y.device, non_blocking=True)
# make sure all GPUs are done writing to 'y'
# Use a barrier to make sure all GPUs are done writing to 'y'
dist.barrier()
g.ndata["h"] = y if use_uva else y.to(device)
@@ -117,7 +152,7 @@ def layerwise_infer(
acc = MF.accuracy(
pred, labels, task="multiclass", num_classes=num_classes
)
print("Test accuracy {:.4f}".format(acc.item()))
print(f"Test accuracy {acc.item():.4f}")
def train(
@@ -132,6 +167,7 @@ def train(
use_uva,
num_epochs,
):
# Instantiate a neighbor sampler
sampler = NeighborSampler(
[10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
)
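Here the fanouts [10, 10, 10] mean that at most 10 neighbors are sampled per node at each of the three layers, and prefetch_node_feats/prefetch_labels ask the dataloader to fetch those fields together with the blocks. A toy single-process sketch of iterating such a sampler follows; the random graph, feature sizes, and batch size are made up, and no DDP is involved here.

import dgl
import torch
from dgl.dataloading import DataLoader, NeighborSampler

g = dgl.rand_graph(1000, 20000)  # random graph: 1000 nodes, 20000 edges
g.ndata["feat"] = torch.randn(g.num_nodes(), 16)
g.ndata["label"] = torch.randint(0, 3, (g.num_nodes(),))

sampler = NeighborSampler(
    [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
)
train_idx = torch.arange(100)  # pretend training node IDs
dataloader = DataLoader(
    g, train_idx, sampler, batch_size=32, shuffle=True, drop_last=False
)
for input_nodes, output_nodes, blocks in dataloader:
    # blocks[0] is the outermost (largest) block, blocks[-1] the innermost.
    print(len(input_nodes), len(output_nodes), [b.num_dst_nodes() for b in blocks])
    break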
@@ -144,7 +180,7 @@ def train(
shuffle=True,
drop_last=False,
num_workers=0,
use_ddp=True,
use_ddp=True, # To split the set for each process
use_uva=use_uva,
)
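With use_ddp=True, each spawned process iterates over its own disjoint shard of train_idx, so the ranks together cover the full training set once per epoch. Conceptually the split behaves roughly like the snippet below; the real dataloader also handles shuffling and uneven shard sizes, so this is only an illustration.

import torch

train_idx = torch.arange(10)  # pretend training node IDs
nprocs = 4
for rank, shard in enumerate(torch.tensor_split(train_idx, nprocs)):
    print(f"rank {rank} trains on nodes {shard.tolist()}")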
val_dataloader = DataLoader(
@@ -164,36 +200,64 @@ def train(
t0 = time.time()
model.train()
total_loss = 0
for it, (_, _, blocks) in enumerate(train_dataloader):
for it, (input_nodes, output_nodes, blocks) in enumerate(
train_dataloader
):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]
y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y)
opt.zero_grad()
loss.backward()
opt.step()
opt.step() # Gradients are synchronized in DDP
total_loss += loss
#####################################################################
# (HIGHLIGHT) Collect accuracy values from sub-processes and obtain
# overall accuracy.
#
# `torch.distributed.reduce` is used to reduce tensors from all the
# sub-processes to a specified process, ReduceOp.SUM is used by default.
#
# Other multiprocess functions supported by the backend are also
# available. Please refer to
# https://pytorch.org/docs/stable/distributed.html
# for more information.
#####################################################################
acc = (
evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
)
t1 = time.time()
dist.reduce(acc, 0)
# Reduce `acc` tensors to process 0.
dist.reduce(tensor=acc, dst=0)
if proc_id == 0:
print(
"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} | "
"Time {:.4f}".format(
epoch, total_loss / (it + 1), acc.item(), t1 - t0
)
f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | "
f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
)
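The highlighted block relies on torch.distributed.reduce with the default ReduceOp.SUM: each rank contributes its local accuracy already divided by nprocs, so after the reduce, rank 0 holds the mean. A small sketch of that aggregation is shown below; it assumes a process group has already been initialized, and the helper name and dummy local value are illustrative.

import torch
import torch.distributed as dist


def average_metric(local_acc, device):
    # Divide locally, then sum across ranks (default op is ReduceOp.SUM),
    # so rank 0 ends up holding the mean over all processes.
    acc = torch.tensor(local_acc, device=device) / dist.get_world_size()
    dist.reduce(acc, dst=0)
    # `acc` is only meaningful on rank 0 after the reduce; dist.all_reduce(acc)
    # would instead make the mean available on every rank.
    return acc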
def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
# find corresponding device for my rank
# Find corresponding device for current process.
device = devices[proc_id]
torch.cuda.set_device(device)
# initialize process group and unpack data for sub-processes
#########################################################################
# (HIGHLIGHT) Build a data-parallel distributed GraphSAGE model.
#
# DDP in PyTorch provides data parallelism across the devices specified
# by the `process_group`. Gradients are synchronized across each model
# replica.
#
# To prepare a training sub-process, there are four steps involved:
# 1. Initialize the process group
# 2. Unpack data for the sub-process.
# 3. Instantiate a GraphSAGE model on the corresponding device.
# 4. Parallelize the model with `DistributedDataParallel`.
#
# For the detailed usage of `DistributedDataParallel`, please refer to
# PyTorch documentation.
#########################################################################
dist.init_process_group(
backend="nccl",
backend="nccl", # Use NCCL backend for distributed GPU training
init_method="tcp://127.0.0.1:12345",
world_size=nprocs,
rank=proc_id,
@@ -202,14 +266,16 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
train_idx = train_idx.to(device)
val_idx = val_idx.to(device)
g = g.to(device if mode == "puregpu" else "cpu")
# create GraphSAGE model (distributed)
in_size = g.ndata["feat"].shape[1]
model = SAGE(in_size, 256, num_classes).to(device)
model = DistributedDataParallel(
model, device_ids=[device], output_device=device
)
# training + testing
# Training.
use_uva = mode == "mixed"
if proc_id == 0:
print("Training...")
train(
proc_id,
nprocs,
@@ -222,8 +288,13 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
use_uva,
num_epochs,
)
# Testing.
if proc_id == 0:
print("Testing...")
layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)
# cleanup process group
# Cleanup the process group.
dist.destroy_process_group()
@@ -269,18 +340,19 @@ if __name__ == "__main__":
), f"Must have GPUs to enable multi-gpu training."
print(f"Training in {args.mode} mode using {nprocs} GPU(s)")
# load and preprocess dataset
# Load and preprocess the dataset.
print("Loading data")
dataset = AsNodePredDataset(
DglNodePropPredDataset(args.dataset_name, root=args.dataset_dir)
)
g = dataset[0]
# avoid creating certain graph formats in each sub-process to save momory
# Explicitly create desired graph formats before multi-processing to avoid
# redundant creation in each sub-process and to save memory.
g.create_formats_()
if args.dataset_name == "ogbn-arxiv":
g = dgl.to_bidirected(g, copy_ndata=True)
g = dgl.add_self_loop(g)
# thread limiting to avoid resource competition
# Thread limiting to avoid resource competition.
os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
data = (
dataset.num_classes,
@@ -289,6 +361,7 @@ if __name__ == "__main__":
dataset.test_idx,
)
# To use DDP with n GPUs, spawn up n processes.
mp.spawn(
run,
args=(nprocs, devices, g, data, args.mode, args.num_epochs),
......