Unverified Commit 2d2ad71e authored by Yi-Chien Lin, committed by GitHub

[Feature] ARGO: an easy-to-use runtime to improve GNN training performance on multi-core processors (#7003)
Co-authored-by: Andrzej Kotłowski <Andrzej.Kotlowski@intel.com>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 5472cd41
@@ -6,6 +6,13 @@ The folder contains example implementations of selected research papers related ...
To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).
## 2024
- <a name="argo"></a> Lin et al. ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor. [Paper link](https://arxiv.org/abs/2402.03671)
- Example code: [PyTorch](https://github.com/dmlc/dgl/tree/master/examples/pytorch/argo)
- Tags: semi-supervised node classification
## 2023
- <a name="labor"></a> Zheng Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2210.13339)
# ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor
## Overview
Graph Neural Network (GNN) training suffers from low scalability on multi-core processors.
ARGO is a runtime system that offers scalable performance.
The figure below shows an example of GNN training on a Xeon 8380H platform with 112 cores.
Without ARGO, there is no performance improvement after applying more than 16 cores; we observe a similar scalability limit on a Xeon 6430L platform with 64 cores as well.
However, with ARGO enabled, training scales beyond 64 cores, speeding up GNN training (in terms of epoch time) by up to 4.30x and 3.32x on a Xeon 8380H and a Xeon 6430L, respectively.
![ARGO](https://github.com/dmlc/dgl/tree/master/examples/pytorch/argo/argo_scale.png)
This README covers:
1. [Installation](#1-installation)
2. [Running the example GNN program](#2-running-the-example-GNN-program)
3. [Enabling ARGO on your own GNN program](#3-enabling-ARGO-on-your-own-GNN-program)
## 1. Installation
1. ARGO utilizes the scikit-optimize library for auto-tuning. Please install scikit-optimize to run ARGO:
```shell
conda install -c conda-forge "scikit-optimize>=0.9.0"
```
or
```shell
pip install "scikit-optimize>=0.9"
```
## 2. Running the example GNN program
### Usage
```shell
python main.py --dataset ogbn-products --sampler shadow --model sage
```
Important Arguments:
- `--dataset`: the training dataset. Available choices: [ogbn-products, ogbn-papers100M, reddit, flickr, yelp]
- `--sampler`: the mini-batch sampling algorithm. Available choices: [shadow, neighbor]
- `--model`: the GNN model. Available choices: [gcn, sage]
- `--layer`: number of GNN layers.
- `--fan_out`: number of fanout neighbors for each layer.
- `--hidden`: hidden feature dimension.
- `--batch_size`: the size of the mini-batch.
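For instance, a run that sets all of these arguments explicitly might look like the following (the flag values are illustrative, not recommended defaults):

```shell
python main.py --dataset ogbn-products --sampler neighbor --model gcn \
    --layer 3 --fan_out 15,10,5 --hidden 256 --batch_size 4096
```

Note that with the `neighbor` sampler, the number of comma-separated fanout values must match `--layer`.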
## 3. Enabling ARGO on your own GNN program
In this section, we provide a step-by-step tutorial on how to enable ARGO on a DGL program. We use the ```ogb_example.py``` file in this repo as an example.
> Note: we also provide the complete example file ```ogb_example_ARGO.py```, which follows the steps below to enable ARGO on ```ogb_example.py```.
1. First, include all necessary packages on top of the file. Please place your file and ```argo.py``` in the same directory.
```python
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
from argo import ARGO
```
2. Set up PyTorch Distributed Data Parallel (DDP).
1. Add the initialization function on top of the training program, and wrap the ```model``` with the DDP wrapper
```python
def train(...):
dist.init_process_group('gloo', rank=rank, world_size=world_size) # newly added
model = SAGE(...) # original code
model = DistributedDataParallel(model) # newly added
...
```
2. In the main program, add the following before launching the training function
```python
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29501'
mp.set_start_method('fork', force=True)
train(args, device, data) # original code for launching the training function
```
3. Enable ARGO by initializing the runtime system and wrapping the training function:
```python
runtime = ARGO(n_search = 15, epoch = args.num_epochs, batch_size = args.batch_size) #initialization
runtime.run(train, args=(args, device, data)) # wrap the training function
```
> ARGO takes three input parameters: the number of searches ```n_search```, the number of epochs, and the mini-batch size. Increasing ```n_search``` potentially leads to a better configuration with a lower epoch time; however, the search itself also adds overhead. We recommend setting ```n_search``` between 15 and 45 for good overall performance. Details of ```n_search``` can be found in the paper.
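As a rough sketch of the epoch accounting (mirroring ```auto_tuning``` and ```run``` in ```argo.py```): each of the ```n_search``` trials trains for one epoch, and the remaining epochs run with the best configuration found. The helper below is purely illustrative and not part of ARGO's API:

```python
def epoch_budget(total_epochs, n_search):
    # Each search trial trains for one epoch (ep=1 inside the auto-tuner);
    # the remaining epochs are replayed with the best configuration found.
    search_epochs = n_search
    tuned_epochs = total_epochs - n_search
    return search_epochs, tuned_epochs

print(epoch_budget(50, 15))  # -> (15, 35)
```

This is why ```n_search``` should stay well below the total epoch count: every extra search trial consumes one epoch of the budget.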
4. Modify the inputs of the training function by appending the ARGO parameters after the original inputs.
This is the original function:
```python
def train(args, device, data):
```
Add ```rank, world_size, comp_core, load_core, counter, b_size, ep``` like this:
```python
def train(args, device, data, rank, world_size, comp_core, load_core, counter, b_size, ep):
```
5. Modify the ```dataloader``` function in the training function
```python
dataloader = dgl.dataloading.DataLoader(
g,
train_nid,
sampler,
batch_size=b_size, # modified
shuffle=True,
drop_last=False,
num_workers=len(load_core), # modified
use_ddp = True) # newly added
```
6. Enable core-binding by adding ```enable_cpu_affinity()``` before the training for-loop, and change the number of epochs to the variable ```ep```:
```python
with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core):
for epoch in range(ep): # change num_epochs to ep
```
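The ```load_core``` and ```comp_core``` lists are produced by ARGO's core binder, which carves a contiguous slice of physical core IDs out for each process. A minimal stand-in for that logic (the core counts here are illustrative; the real implementation is ```core_binder``` in ```argo.py```):

```python
def bind_cores(n_cores, n_proc, n_samp, n_train, rank):
    # Each process gets a contiguous slice of physical core IDs:
    # the first n_samp cores load (sample) data, the next n_train compute.
    base = n_cores // n_proc * rank
    load_core = list(range(base, base + n_samp))
    comp_core = list(range(base + n_samp, base + n_samp + n_train))
    return load_core, comp_core

# e.g. 16 physical cores, 2 processes, 2 samplers + 4 trainers per process:
print(bind_cores(16, 2, 2, 4, 0))  # -> ([0, 1], [2, 3, 4, 5])
print(bind_cores(16, 2, 2, 4, 1))  # -> ([8, 9], [10, 11, 12, 13])
```

No core ID appears in more than one list, so sampling and training never contend for the same physical core.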
7. Last step! Load the model before training and save it afterward.
Original Program:
```python
with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core):
for epoch in range(ep):
... # training operations
```
Modified:
```python
PATH = "model.pt"
if counter[0] != 0:
checkpoint = th.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
with dataloader.enable_cpu_affinity(loader_cores=load_core, compute_cores=comp_core):
for epoch in range(ep):
... # training operations
dist.barrier()
if rank == 0:
th.save({'epoch': counter[0],
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, PATH)
```
8. Done! You can now run your GNN program with ARGO enabled.
```shell
python <your_code>.py
```
## Citation & Acknowledgement
This work has been supported by the U.S. National Science Foundation (NSF) under grants CCF-1919289/SPX-2333009, CNS-2009057 and OAC-2209563, and the Semiconductor Research Corporation (SRC).
```
@INPROCEEDINGS{argo-ipdps24,
author={Yi-Chien Lin and Yuyang Chen and Sameh Gobriel and Nilesh Jain and Gopi Krishna Jha and Viktor Prasanna},
booktitle={IEEE International Parallel and Distributed Processing Symposium (IPDPS)},
title={ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor},
year={2024}}
```
"""
ARGO: An Auto-Tuning Runtime System for Scalable GNN Training on Multi-Core Processor
--------------------------------------------
Graph Neural Network (GNN) training suffers from low scalability on multi-core CPUs.
Specifically, the performance often caps at 16 cores, and no improvement is observed when applying more than 16 cores.
ARGO is a runtime system that offers scalable performance by overlapping the computation and communication during GNN training.
With ARGO enabled, training scales beyond 64 cores, speeding up GNN training (in terms of epoch time) by up to 4.30x and 3.32x on a Xeon 8380H and a Xeon 6430L, respectively.
--------------------------------------------
Paper Link: https://arxiv.org/abs/2402.03671
"""
import time
from typing import Callable, List, Tuple
import dgl.multiprocessing as dmp
import numpy as np
import psutil
from skopt import gp_minimize
from skopt.space import Normalize
def transform(self, X):
X = np.asarray(X)
if self.is_int:
if np.any(np.round(X) > self.high):
            raise ValueError(
                "All integer values should be less than %f" % self.high
            )
if np.any(np.round(X) < self.low):
            raise ValueError(
                "All integer values should be greater than %f" % self.low
            )
else:
if np.any(X > self.high + self._eps):
            raise ValueError("All values should be less than %f" % self.high)
if np.any(X < self.low - self._eps):
            raise ValueError(
                "All values should be greater than %f" % self.low
            )
if (self.high - self.low) == 0.0:
return X * 0.0
if self.is_int:
return (np.round(X).astype(int) - self.low) / (self.high - self.low)
else:
return (X - self.low) / (self.high - self.low)
def inverse_transform(self, X):
X = np.asarray(X)
if np.any(X > 1.0 + self._eps):
raise ValueError("All values should be less than 1.0")
if np.any(X < 0.0 - self._eps):
raise ValueError("All values should be greater than 0.0")
X_orig = X * (self.high - self.low) + self.low
if self.is_int:
return np.round(X_orig).astype(int)
return X_orig
# This is a workaround for scikit-optimize's incompatibility with newer NumPy versions, which results in the error:
# AttributeError: module 'numpy' has no attribute 'int'
Normalize.transform = transform
Normalize.inverse_transform = inverse_transform
class ARGO:
def __init__(
self,
n_search=10,
epoch=200,
batch_size=4096,
space=[(2, 8), (1, 4), (1, 32)],
random_state=1,
):
"""
Initialization
Parameters
----------
n_search: int
Number of configuration searches the auto-tuner will conduct
epoch: int
Number of epochs of GNN training
batch_size: int
Size of the mini-batch
space: list[Tuple(int,int)]
Range of the search space; [range of processes, range of samplers for each process, range of trainers for each process]
random_state: int
Number of random initializations before searching
"""
self.n_search = n_search
self.epoch = epoch
self.batch_size = batch_size
self.space = space
self.random_state = random_state
self.acq_func = "EI"
self.counter = [0]
def core_binder(
self, num_cpu_proc: int, n_samp: int, n_train: int, rank: int
) -> Tuple[List[int], List[int]]:
"""
Core Binder
The Core Binder binds CPU cores to perform sampling (i.e., sampling cores) and model propagation (i.e., training cores).
The actual binding is done using the CPU affinity function in the data_loader.
The core_binder function here is used to produce the list of CPU IDs for the CPU affinity function.
Parameters
----------
num_cpu_proc: int
Number of processes instantiated
n_samp: int
Number of sampling cores for each process
n_train: int
Number of training cores for each process
rank: int
The rank of the current process
Returns: Tuple[list[int], list[int]]
-------
load_core: list[int]
For a given process rank, the load_core specifies a list of CPU core IDs to be used for sampling, the length of load_core = n_samp.
        comp_core: list[int]
            For a given process rank, the comp_core specifies a list of CPU core IDs to be used for training, the length of comp_core = n_train.
.. note:: Each process is assigned with a unique list of sampling cores and training cores, and no CPU core will appear in two lists or more.
"""
load_core, comp_core = [], []
n = psutil.cpu_count(logical=False)
size = num_cpu_proc
num_of_samplers = n_samp
load_core = list(
range(n // size * rank, n // size * rank + num_of_samplers)
)
comp_core = list(
range(
n // size * rank + num_of_samplers,
n // size * rank + num_of_samplers + n_train,
)
)
return load_core, comp_core
def auto_tuning(self, train: Callable, args) -> List[int]:
"""
Auto-tuner
The auto-tuner runs Bayesian Optimization (BO) to search for the optimal configuration (number of processes, samplers, trainers).
During the search, the auto-tuner explores the design space by collecting the epoch time of various configurations.
        Specifically, the exploration is done by feeding the Multi-Process Engine with various configurations and recording the epoch time.
        After the search is done, the optimal configuration is used repeatedly until the end of model training.
Parameters
----------
train: Callable
The GNN training function.
args:
The inputs of the GNN training function.
Returns
-------
result: list[int]
The optimal configurations (which leads to the shortest epoch time) found by running BO.
- result[0]: number of processes to instantiate
- result[1]: number of sampling cores for each process
- result[2]: number of training cores for each process
"""
ep = 1
result = gp_minimize(
lambda x: self.mp_engine(x, train, args, ep),
dimensions=self.space,
n_calls=self.n_search,
random_state=self.random_state,
acq_func=self.acq_func,
)
return result
def mp_engine(self, x: List[int], train: Callable, args, ep: int) -> float:
"""
Multi-Process Engine (MP Engine)
The MP Engine launches multiple GNN training processes in parallel to overlap computation with communication.
Such an approach effectively improves the utilization of the memory bandwidth and the CPU cores.
        The MP Engine also adjusts the batch size according to the number of processes instantiated, so that the effective batch size remains the same as in the original program without ARGO.
Parameters
----------
        x: list[int]
            The candidate configuration provided by the auto-tuner.
- x[0]: number of processes to instantiate
- x[1]: number of sampling cores for each process
- x[2]: number of training cores for each process
train: Callable
The GNN training function.
args:
The inputs of the GNN training function.
ep: int
number of epochs.
Returns
-------
t: float
The epoch time using the current configuration `x`.
"""
n_proc = x[0]
n_samp = x[1]
n_train = x[2]
n_total = psutil.cpu_count(logical=False)
if n_proc * (n_samp + n_train) > n_total: # handling corner cases
n_proc = 2
n_samp = 2
n_train = (n_total // n_proc) - n_samp
processes = []
cnt = self.counter
b_size = self.batch_size // n_proc # adjust batch size
tik = time.time()
for i in range(n_proc):
load_core, comp_core = self.core_binder(n_proc, n_samp, n_train, i)
p = dmp.Process(
target=train,
args=(*args, i, n_proc, comp_core, load_core, cnt, b_size, ep),
)
p.start()
processes.append(p)
for p in processes:
p.join()
t = time.time() - tik
self.counter[0] = self.counter[0] + 1
return t
def run(self, train, args):
"""
        The "run" function launches ARGO to train the GNN model.
Step 1: run the auto-tuner to search for the optimal configuration
Step 2: record the optimal configuration
Step 3: use the optimal configuration repeatedly until the end of the model training
Parameters
----------
train: Callable
The GNN training function.
args:
The inputs of the GNN training function.
"""
result = self.auto_tuning(train, args) # Step 1
x = result.x # Step 2
self.mp_engine(
x, train, args, ep=(self.epoch - self.n_search)
) # Step 3
import argparse
import os
import dgl
import dgl.nn as dglnn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from argo import ARGO
from dgl.data import (
AsNodePredDataset,
FlickrDataset,
RedditDataset,
YelpDataset,
)
from dgl.dataloading import DataLoader, NeighborSampler, ShaDowKHopSampler
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
class GNN(nn.Module):
def __init__(
self, in_size, hid_size, out_size, num_layers=3, model_name="sage"
):
super().__init__()
self.layers = nn.ModuleList()
# GraphSAGE-mean
if model_name.lower() == "sage":
self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
for i in range(num_layers - 2):
self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
# GCN
elif model_name.lower() == "gcn":
kwargs = {
"norm": "both",
"weight": True,
"bias": True,
"allow_zero_in_degree": True,
}
self.layers.append(dglnn.GraphConv(in_size, hid_size, **kwargs))
for i in range(num_layers - 2):
self.layers.append(
dglnn.GraphConv(hid_size, hid_size, **kwargs)
)
self.layers.append(dglnn.GraphConv(hid_size, out_size, **kwargs))
else:
raise NotImplementedError
self.dropout = nn.Dropout(0.5)
self.hid_size = hid_size
self.out_size = out_size
def forward(self, blocks, x):
h = x
if hasattr(blocks, "__len__"):
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
else:
for l, layer in enumerate(self.layers):
h = layer(blocks, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h
def _train(**kwargs):
total_loss = 0
loader = kwargs["loader"]
model = kwargs["model"]
opt = kwargs["opt"]
load_core = kwargs["load_core"]
comp_core = kwargs["comp_core"]
device = torch.device("cpu")
with loader.enable_cpu_affinity(
loader_cores=load_core, compute_cores=comp_core
):
for it, (input_nodes, output_nodes, blocks) in enumerate(loader):
if hasattr(blocks, "__len__"):
x = blocks[0].srcdata["feat"].to(torch.float32)
y = blocks[-1].dstdata["label"]
else:
x = blocks.srcdata["feat"].to(torch.float32)
y = blocks.dstdata["label"]
if kwargs["device"] == "cpu": # for papers100M
y = y.type(torch.LongTensor)
y_hat = model(blocks, x)
else:
y = y.type(torch.LongTensor).to(device)
y_hat = model(blocks, x).to(device)
try:
loss = F.cross_entropy(
y_hat[: output_nodes.shape[0]], y[: output_nodes.shape[0]]
)
            except Exception:  # fall back for multi-label datasets (e.g., yelp)
loss = F.binary_cross_entropy_with_logits(
y_hat[: output_nodes.shape[0]].float(),
y[: output_nodes.shape[0]].float(),
reduction="sum",
)
opt.zero_grad()
loss.backward()
opt.step()
del input_nodes, output_nodes, blocks
total_loss += loss.item()
return total_loss
def train(
args, g, data, rank, world_size, comp_core, load_core, counter, b_size, ep
):
num_classes, train_idx = data
dist.init_process_group("gloo", rank=rank, world_size=world_size)
device = torch.device("cpu")
hidden = args.hidden
# create GraphSAGE model
in_size = g.ndata["feat"].shape[1]
model = GNN(
in_size,
hidden,
num_classes,
num_layers=args.layer,
model_name=args.model,
).to(device)
model = DistributedDataParallel(model)
num_of_samplers = len(load_core)
# create loader
drop_last, shuffle = True, True
if args.sampler.lower() == "neighbor":
sampler = NeighborSampler(
[int(fanout) for fanout in args.fan_out.split(",")],
prefetch_node_feats=["feat"],
prefetch_labels=["label"],
)
assert len(sampler.fanouts) == args.layer
elif args.sampler.lower() == "shadow":
sampler = ShaDowKHopSampler(
[10, 5],
output_device=device,
prefetch_node_feats=["feat"],
)
else:
raise NotImplementedError
train_dataloader = DataLoader(
g,
train_idx.to(device),
sampler,
device=device,
batch_size=b_size,
drop_last=drop_last,
shuffle=shuffle,
num_workers=num_of_samplers,
use_ddp=True,
)
# training loop
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
params = {
# training
"loader": train_dataloader,
"model": model,
"opt": opt,
# logging
"rank": rank,
"train_size": len(train_idx),
"batch_size": b_size,
"device": device,
"process": world_size,
}
PATH = "model.pt"
if counter[0] != 0:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
opt.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
for epoch in range(ep):
params["epoch"] = epoch
model.train()
params["load_core"] = load_core
params["comp_core"] = comp_core
loss = _train(**params)
if rank == 0:
print("loss:", loss)
dist.barrier()
EPOCH = counter[0]
LOSS = loss
if rank == 0:
torch.save(
{
"epoch": EPOCH,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"loss": LOSS,
},
PATH,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=[
"ogbn-papers100M",
"ogbn-products",
"reddit",
"yelp",
"flickr",
],
)
parser.add_argument("--batch_size", type=int, default=1024 * 4)
parser.add_argument("--layer", type=int, default=3)
parser.add_argument("--fan_out", type=str, default="15,10,5")
parser.add_argument(
"--sampler",
type=str,
default="neighbor",
choices=["neighbor", "shadow"],
)
parser.add_argument(
"--model", type=str, default="sage", choices=["sage", "gcn"]
)
parser.add_argument("--hidden", type=int, default=128)
arguments = parser.parse_args()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
if arguments.dataset in ["reddit", "flickr", "yelp"]:
if arguments.dataset == "reddit":
dataset = RedditDataset()
elif arguments.dataset == "flickr":
dataset = FlickrDataset()
else:
dataset = YelpDataset()
g = dataset[0]
train_mask = g.ndata["train_mask"]
idx = []
for i in range(len(train_mask)):
if train_mask[i]:
idx.append(i)
dataset.train_idx = torch.tensor(idx)
else:
dataset = AsNodePredDataset(DglNodePropPredDataset(arguments.dataset))
g = dataset[0]
data = (dataset.num_classes, dataset.train_idx)
in_size = g.ndata["feat"].shape[1]
out_size = dataset.num_classes
hidden_size = int(arguments.hidden)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29501"
mp.set_start_method("fork", force=True)
runtime = ARGO(n_search=10, epoch=20, batch_size=arguments.batch_size)
runtime.run(train, args=(arguments, g, data))
"""
This is a modified version of: https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py
"""
import argparse
import time
import dgl
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
class SAGE(nn.Module):
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[: block.num_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
return h
def inference(self, g, x, device):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
g : the entire graph.
x : the input of entire node set.
        The inference code is written so that it can handle any number of nodes and
        layers.
"""
# During inference with sampling, multi-layer blocks are very inefficient because
# lots of computations in the first few layers are repeated.
# Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
# TODO: can we standardize this?
for l, layer in enumerate(self.layers):
y = th.zeros(
g.num_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
).to(device)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(
dataloader, disable=None
):
block = blocks[0].int().to(device)
h = x[input_nodes]
h_dst = h[: block.num_dst_nodes()]
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h
x = y
return y
def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, nfeat, labels, val_nid, test_nid, device):
"""
Evaluate the model on the validation set specified by ``val_mask``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
    val_mask : A 0-1 mask indicating which nodes we actually compute the accuracy for.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.inference(g, nfeat, device)
model.train()
return (
compute_acc(pred[val_nid], labels[val_nid]),
compute_acc(pred[test_nid], labels[test_nid]),
pred,
)
def load_subtensor(nfeat, labels, seeds, input_nodes):
"""
Extracts features and labels for a set of nodes.
"""
batch_inputs = nfeat[input_nodes]
batch_labels = labels[seeds]
return batch_inputs, batch_labels
#### Entry point
def train(args, device, data):
# Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, nfeat, g = data
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(",")]
)
dataloader = dgl.dataloading.DataLoader(
g,
train_nid,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers,
)
# Define model and optimizer
model = SAGE(
in_feats,
args.num_hidden,
n_classes,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
# Training loop
avg = 0
iter_tput = []
best_eval_acc = 0
best_test_acc = 0
with dataloader.enable_cpu_affinity():
for epoch in range(args.num_epochs):
tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
# copy block to gpu
blocks = [blk.int().to(device) for blk in blocks]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(
nfeat, labels, seeds, input_nodes
)
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
iter_tput.append(len(seeds) / (time.time() - tic_step))
if step % args.log_every == 0 and step != 0:
acc = compute_acc(batch_pred, batch_labels)
print(
"Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}".format(
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
)
)
toc = time.time()
print("Epoch Time(s): {:.4f}".format(toc - tic))
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc, test_acc, pred = evaluate(
model, g, nfeat, labels, val_nid, test_nid, device
)
if args.save_pred:
np.savetxt(
args.save_pred + "%02d" % epoch,
pred.argmax(1).cpu().numpy(),
"%d",
)
print("Eval Acc {:.4f}".format(eval_acc))
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_test_acc = test_acc
print(
"Best Eval Acc {:.4f} Test Acc {:.4f}".format(
best_eval_acc, best_test_acc
)
)
print("Avg epoch time: {}".format(avg / args.num_epochs))
return best_test_acc
if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument(
"--gpu",
type=int,
default=0,
help="GPU device ID. Use -1 for CPU training",
)
argparser.add_argument("--num-epochs", type=int, default=20)
argparser.add_argument("--num-hidden", type=int, default=256)
argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument("--fan-out", type=str, default="5,10,15")
argparser.add_argument("--batch-size", type=int, default=1000)
argparser.add_argument("--val-batch-size", type=int, default=10000)
argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--lr", type=float, default=0.003)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-papers100M", "ogbn-products"],
)
argparser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of sampling processes. Use 0 for no extra process.",
)
argparser.add_argument("--save-pred", type=str, default="")
argparser.add_argument("--wd", type=float, default=0)
args = argparser.parse_args()
device = th.device("cpu")
# load ogbn-products data
data = DglNodePropPredDataset(args.dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0]
nfeat = graph.ndata.pop("feat").to(device)
labels = labels[:, 0].to(device)
in_feats = nfeat.shape[1]
n_classes = (labels.max() + 1).item()
# Create csr/coo/csc formats before launching sampling processes
    # This avoids creating certain formats in each data loader process, which saves memory and CPU.
graph.create_formats_()
# Pack data
data = (
train_idx,
val_idx,
test_idx,
in_feats,
labels,
n_classes,
nfeat,
graph,
)
test_acc = train(args, device, data).cpu().numpy()
print("Test accuracy:", test_acc)
"""
This is a modified version of: https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-products/graphsage/main.py
This example shows how to enable ARGO to automatically instantiate multi-processing and adjust CPU core assignment to achieve better performance.
"""
import argparse
import ctypes
import os
import time
from multiprocessing import RawValue
import dgl
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from argo import ARGO
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
avg_total = RawValue(ctypes.c_float, 0.0)
class SAGE(nn.Module):
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
# We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[: block.num_dst_nodes()]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
return h
def inference(self, g, x, device):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
g : the entire graph.
x : the input of entire node set.
        The inference code is written so that it can handle any number of nodes and
        layers.
"""
# During inference with sampling, multi-layer blocks are very inefficient because
# lots of computations in the first few layers are repeated.
# Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split into batches.
# TODO: can we standardize this?
for l, layer in enumerate(self.layers):
y = th.zeros(
g.num_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
).to(device)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(
dataloader, disable=None
):
block = blocks[0].int().to(device)
h = x[input_nodes]
h_dst = h[: block.num_dst_nodes()]
h = layer(block, (h, h_dst))
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h
x = y
return y
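The slicing `h[: block.num_dst_nodes()]` used in `forward` and `inference` above relies on a DGL convention: a block's destination nodes are ordered first among its source nodes. The following is a minimal list-based sketch of that layout, with hypothetical shapes and no DGL dependency:

```python
# Hypothetical example: a block with 5 source nodes whose first 2 are also
# destination nodes (DGL orders a block's src nodes so dst nodes come first).
num_dst_nodes = 2
h = [[0.1], [0.2], [0.3], [0.4], [0.5]]  # shape (num_src_nodes, D) with D = 1
h_dst = h[:num_dst_nodes]  # shape (num_dst_nodes, D): features of the dst nodes
assert h_dst == [[0.1], [0.2]]
```

Because of this ordering, slicing the first `num_dst_nodes()` rows is all that is needed to obtain the destination-node features.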
def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, nfeat, labels, val_nid, test_nid, device):
"""
Evaluate the model on the validation set specified by ``val_mask``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.module.inference(g, nfeat, device)
model.train()
return (
compute_acc(pred[val_nid], labels[val_nid]),
compute_acc(pred[test_nid], labels[test_nid]),
pred,
)
def load_subtensor(nfeat, labels, seeds, input_nodes):
"""
Extracts features and labels for a set of nodes.
"""
batch_inputs = nfeat[input_nodes]
batch_labels = labels[seeds]
return batch_inputs, batch_labels
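ARGO's central tuning decision is how to divide the available CPU cores between data-loading (sampler) processes and model compute, i.e. the `load_core` / `comp_core` lists that `train` below receives. The helper here is a minimal illustrative sketch of such a partition; it is not part of ARGO's API, and the name `split_cores` and its parameters are assumptions made for illustration only.

```python
# NOTE: illustrative sketch only -- not part of ARGO's API.
# Partition core IDs [0, num_cores) between loading and compute roles,
# giving `num_loaders` cores to data loading and the rest to compute.
def split_cores(num_cores, num_loaders):
    if not 0 < num_loaders < num_cores:
        raise ValueError("num_loaders must be in (0, num_cores)")
    core_ids = list(range(num_cores))
    load_core = core_ids[:num_loaders]  # cores pinned to DataLoader workers
    comp_core = core_ids[num_loaders:]  # cores pinned to model compute
    return load_core, comp_core
```

For example, `split_cores(8, 2)` yields `([0, 1], [2, 3, 4, 5, 6, 7])`; instead of fixing such a split by hand, ARGO searches over configurations like this (together with the number of processes) during its auto-tuning phase.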
#### Entry point
def train(
args,
device,
data,
rank,
world_size,
comp_core,
load_core,
counter,
b_size,
ep,
):
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, nfeat, g = data
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(",")]
)
dataloader = dgl.dataloading.DataLoader(
g,
train_nid,
sampler,
batch_size=b_size,
shuffle=True,
drop_last=False,
num_workers=len(load_core),
use_ddp=True,
)
# Define model and optimizer
model = SAGE(
in_feats,
args.num_hidden,
n_classes,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device)
model = DistributedDataParallel(model)
loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
# Training loop
avg = 0
iter_tput = []
best_eval_acc = 0
best_test_acc = 0
PATH = "model.pt"
if counter[0] != 0:
checkpoint = th.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
with dataloader.enable_cpu_affinity(
loader_cores=load_core, compute_cores=comp_core
):
for epoch in range(ep):
tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
                # Copy the blocks to the target device (CPU in this example)
blocks = [blk.int().to(device) for blk in blocks]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(
nfeat, labels, seeds, input_nodes
)
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
iter_tput.append(len(seeds) / (time.time() - tic_step))
if step % args.log_every == 0 and step != 0:
acc = compute_acc(batch_pred, batch_labels)
print(
"Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}".format(
step,
loss.item(),
acc.item(),
np.mean(iter_tput[3:]),
)
)
toc = time.time()
print("Epoch Time(s): {:.4f}".format(toc - tic))
if rank == 0:
global avg_total
avg_total.value += toc - tic
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc, test_acc, pred = evaluate(
model, g, nfeat, labels, val_nid, test_nid, device
)
if args.save_pred:
np.savetxt(
args.save_pred + "%02d" % epoch,
pred.argmax(1).cpu().numpy(),
"%d",
)
print("Eval Acc {:.4f}".format(eval_acc))
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_test_acc = test_acc
print(
"Best Eval Acc {:.4f} Test Acc {:.4f}".format(
best_eval_acc, best_test_acc
)
)
dist.barrier()
if rank == 0:
th.save(
{
"epoch": counter[0],
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
},
PATH,
)
if args.num_epochs == counter[0] + epoch + 1:
print(
"Avg epoch time: {}".format(avg_total.value / args.num_epochs)
)
print(
"Avg epoch time after auto-tuning: {}".format(avg / (epoch + 1))
)
return best_test_acc
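Because ARGO may relaunch `train` several times while it searches for a good configuration, the loop above checkpoints every epoch and reloads the state when `counter[0] != 0`. The snippet below is a minimal stdlib-only sketch of that save/resume pattern; the `pickle`-based format and the function names are illustrative assumptions (the actual script uses `torch.save` / `torch.load` with model and optimizer state dicts).

```python
import os
import pickle

# Illustrative sketch of the checkpoint/resume pattern used in train() above
# (the real script stores model/optimizer state via torch.save / torch.load).
def save_checkpoint(path, epoch, state):
    with open(path, "wb") as f:
        pickle.dump({"epoch": epoch, "state": state}, f)

def load_checkpoint(path):
    # Resume from the saved epoch if a checkpoint exists; otherwise start fresh.
    if os.path.exists(path):
        with open(path, "rb") as f:
            ckpt = pickle.load(f)
        return ckpt["epoch"], ckpt["state"]
    return 0, None
```

The key property, mirrored in `train` above, is that resuming is a no-op on the first launch and restores the last saved epoch on every relaunch.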
if __name__ == "__main__":
    argparser = argparse.ArgumentParser("ARGO multi-core CPU training")
argparser.add_argument(
"--gpu",
type=int,
default=0,
help="GPU device ID. Use -1 for CPU training",
)
argparser.add_argument("--num-epochs", type=int, default=20)
argparser.add_argument("--num-hidden", type=int, default=256)
argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument("--fan-out", type=str, default="5,10,15")
argparser.add_argument("--batch-size", type=int, default=1000)
argparser.add_argument("--val-batch-size", type=int, default=10000)
argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--lr", type=float, default=0.003)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-papers100M", "ogbn-products"],
)
argparser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of sampling processes. Use 0 for no extra process.",
)
argparser.add_argument("--save-pred", type=str, default="")
argparser.add_argument("--wd", type=float, default=0)
args = argparser.parse_args()
device = th.device("cpu")
# load ogbn-products data
data = DglNodePropPredDataset(args.dataset)
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0]
nfeat = graph.ndata.pop("feat").to(device)
labels = labels[:, 0].to(device)
in_feats = nfeat.shape[1]
n_classes = (labels.max() + 1).item()
# Create csr/coo/csc formats before launching sampling processes
    # This avoids creating certain formats in each data loader process, which saves memory and CPU.
graph.create_formats_()
# Pack data
data = (
train_idx,
val_idx,
test_idx,
in_feats,
labels,
n_classes,
nfeat,
graph,
)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29501"
mp.set_start_method("fork", force=True)
runtime = ARGO(
n_search=15, epoch=args.num_epochs, batch_size=args.batch_size
) # initialization
runtime.run(train, args=(args, device, data)) # wrap the training function