Unverified Commit 27a6eb56 authored by nv-dlasalle, committed by GitHub

[Examples] Update graphsage multi-gpu example to use multiple GPUs for validation and testing. (#3827)

* Update graphsage multi-gpu example to use multiple GPUs for validation and testing.

* Remove argmax

* Fix rebase error

* Add more documentation to example and simplify

* Switch to named shared memory

* Add comment about how training is distributed

* Restore iteration count

* Fix munmap error reporting for better error messages

Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 11e910b6
@@ -6,11 +6,45 @@ import torch.distributed.optim
 import torchmetrics.functional as MF
 import dgl
 import dgl.nn as dglnn
+from dgl.utils import pin_memory_inplace, unpin_memory_inplace, \
+    gather_pinned_tensor_rows, create_shared_mem_array, get_shared_mem_array
 import time
 import numpy as np
 from ogb.nodeproppred import DglNodePropPredDataset
 import tqdm
+
+
+def shared_tensor(*shape, device, name, dtype=torch.float32):
+    """ Create a tensor in shared memory, pinned in each process's CUDA
+    context.
+
+    Parameters
+    ----------
+    shape : int...
+        A sequence of integers describing the shape of the new tensor.
+    device : context
+        The device of the result tensor.
+    name : string
+        The name of the shared allocation.
+    dtype : dtype, optional
+        The datatype of the allocation. Default: torch.float32
+
+    Returns
+    -------
+    Tensor :
+        The shared tensor.
+    """
+    rank = dist.get_rank()
+    if rank == 0:
+        y = create_shared_mem_array(name, shape, dtype)
+    dist.barrier()
+    if rank != 0:
+        y = get_shared_mem_array(name, shape, dtype)
+    pin_memory_inplace(y)
+    return y
+
+
 class SAGE(nn.Module):
     def __init__(self, in_feats, n_hidden, n_classes):
         super().__init__()
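The helper above follows a create-then-attach pattern: rank 0 creates the named host-memory allocation, the remaining ranks look it up by name after the barrier, and every rank pins it for UVA access. Below is a minimal usage sketch, assuming the DDP process group is already initialized and that shared_tensor is in scope; the function name, buffer name, and sizes are illustrative only, not part of the commit:

import torch.distributed as dist

def fill_shared_output(rank, world_size, n_rows=1000, n_cols=16):
    # rank 0 creates the named buffer, the other ranks attach to it by name
    y = shared_tensor(n_rows, n_cols, device='cpu', name='demo_output')
    # every rank writes a disjoint slice of the same host-memory tensor
    chunk = (n_rows + world_size - 1) // world_size
    y[rank * chunk:(rank + 1) * chunk] = float(rank)
    # wait until all ranks have finished writing before anyone reads
    dist.barrier()
    return y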
@@ -22,39 +56,75 @@ class SAGE(nn.Module):
         self.n_hidden = n_hidden
         self.n_classes = n_classes
 
+    def _forward_layer(self, l, block, x):
+        h = self.layers[l](block, x)
+        if l != len(self.layers) - 1:
+            h = F.relu(h)
+            h = self.dropout(h)
+        return h
+
     def forward(self, blocks, x):
         h = x
         for l, (layer, block) in enumerate(zip(self.layers, blocks)):
-            h = layer(block, h)
-            if l != len(self.layers) - 1:
-                h = F.relu(h)
-                h = self.dropout(h)
+            h = self._forward_layer(l, blocks[l], h)
         return h
 
-    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
-        # The difference between this inference function and the one in the official
-        # example is that the intermediate results can also benefit from prefetching.
+    def inference(self, g, device, batch_size):
+        """
+        Perform inference in layer-major order rather than batch-major order.
+        That is, infer the first layer for the entire graph, and store the
+        intermediate values h_0, before inferring the second layer to generate
+        h_1. This is done for two reasons: 1) it limits the effect of node
+        degree on the amount of memory used as it only processes 1-hop
+        neighbors at a time, and 2) it reduces the total amount of computation
+        required as each node is only processed once per layer.
+
+        Parameters
+        ----------
+        g : DGLGraph
+            The graph to perform inference on.
+        device : context
+            The device this process should use for inference.
+        batch_size : int
+            The number of items to collect in a batch.
+
+        Returns
+        -------
+        tensor
+            The predictions for all nodes in the graph.
+        """
         g.ndata['h'] = g.ndata['feat']
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
         dataloader = dgl.dataloading.DataLoader(
-            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
-            batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers,
-            persistent_workers=(num_workers > 0))
-        if buffer_device is None:
-            buffer_device = device
+            g, torch.arange(g.num_nodes(), device=device), sampler, device=device,
+            batch_size=batch_size, shuffle=False, drop_last=False,
+            num_workers=0, use_ddp=True, use_uva=True)
 
         for l, layer in enumerate(self.layers):
-            y = torch.zeros(
-                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
-                device=buffer_device)
-            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+            # in order to prevent running out of GPU memory, we allocate a
+            # shared output tensor 'y' in host memory, pin it to allow UVA
+            # access from each GPU during forward propagation.
+            y = shared_tensor(
+                g.num_nodes(),
+                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
+                device='cpu', name='layer_{}_output'.format(l))
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader) \
+                    if dist.get_rank() == 0 else dataloader:
                 x = blocks[0].srcdata['h']
-                h = layer(blocks[0], x)
-                if l != len(self.layers) - 1:
-                    h = F.relu(h)
-                    h = self.dropout(h)
-                y[output_nodes] = h.to(buffer_device)
-            g.ndata['h'] = y
+                h = self._forward_layer(l, blocks[0], x)
+                y[output_nodes] = h.to(y.device)
+            # make sure all GPUs are done writing to 'y'
+            dist.barrier()
+            if l > 0:
+                unpin_memory_inplace(g.ndata['h'])
+            if l + 1 < len(self.layers):
+                # assign the output features of this layer as the new input
+                # features for the next layer
+                g.ndata['h'] = y
+            else:
+                # remove the intermediate data from the graph
+                g.ndata.pop('h')
         return y
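The docstring above motivates layer-major inference: propagate one layer over the whole graph before starting the next, so memory scales with 1-hop neighborhoods and each node is computed only once per layer. A toy, DGL-free sketch of the same idea using a dense normalized adjacency matrix; all names here are illustrative and not part of the example:

import torch
import torch.nn.functional as F

def layer_major_inference(adj, feats, weights):
    # adj: (N, N) normalized adjacency, feats: (N, F_in),
    # weights: list of per-layer (F_l, F_{l+1}) matrices
    h = feats
    for l, w in enumerate(weights):
        # one full pass over all nodes per layer; only the current
        # intermediate h has to be kept, not multi-hop neighborhoods
        h = adj @ (h @ w)
        if l != len(weights) - 1:
            h = F.relu(h)
    return h

# tiny self-contained check
N = 8
pred = layer_major_inference(torch.eye(N), torch.randn(N, 4),
                             [torch.randn(4, 16), torch.randn(16, 3)])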
@@ -68,18 +138,25 @@ def train(rank, world_size, graph, num_classes, split_idx):
     train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
 
+    # move ids to GPU
     train_idx = train_idx.to('cuda')
     valid_idx = valid_idx.to('cuda')
+    test_idx = test_idx.to('cuda')
 
+    # For training, each process/GPU will get a subset of the
+    # train_idx/valid_idx, and generate mini-batches independently. This allows
+    # the only communication necessary in training to be the all-reduce for
+    # the gradients performed by the DDP wrapper (created above).
     sampler = dgl.dataloading.NeighborSampler(
             [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
     train_dataloader = dgl.dataloading.DataLoader(
             graph, train_idx, sampler,
-            device='cuda', batch_size=1000, shuffle=True, drop_last=False,
+            device='cuda', batch_size=1024, shuffle=True, drop_last=False,
             num_workers=0, use_ddp=True, use_uva=True)
     valid_dataloader = dgl.dataloading.DataLoader(
             graph, valid_idx, sampler, device='cuda', batch_size=1024, shuffle=True,
-            drop_last=False, num_workers=0, use_uva=True)
+            drop_last=False, num_workers=0, use_ddp=True,
+            use_uva=True)
 
     durations = []
     for _ in range(10):
@@ -93,33 +170,38 @@ def train(rank, world_size, graph, num_classes, split_idx):
             opt.zero_grad()
             loss.backward()
             opt.step()
-            if it % 20 == 0:
+            if it % 20 == 0 and rank == 0:
                 acc = MF.accuracy(y_hat, y)
                 mem = torch.cuda.max_memory_allocated() / 1000000
                 print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
         tt = time.time()
         if rank == 0:
             print(tt - t0)
         durations.append(tt - t0)
 
         model.eval()
         ys = []
         y_hats = []
         for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
             with torch.no_grad():
                 x = blocks[0].srcdata['feat']
                 ys.append(blocks[-1].dstdata['label'])
                 y_hats.append(model.module(blocks, x))
-        acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
-        print('Validation acc:', acc.item())
+        acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys)) / world_size
+        dist.reduce(acc, 0)
+        if rank == 0:
+            print('Validation acc:', acc.item())
         dist.barrier()
 
     if rank == 0:
         print(np.mean(durations[4:]), np.std(durations[4:]))
 
     model.eval()
     with torch.no_grad():
-        pred = model.module.inference(graph, 'cuda', 1000, 12, graph.device)
-        acc = MF.accuracy(pred.to(graph.device), graph.ndata['label'])
-        print('Test acc:', acc.item())
+        # since we do 1-layer at a time, use a very large batch size
+        pred = model.module.inference(graph, device='cuda', batch_size=2**16)
+        if rank == 0:
+            acc = MF.accuracy(pred[test_idx], graph.ndata['label'][test_idx])
+            print('Test acc:', acc.item())
 
 if __name__ == '__main__':
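In the validation block above, each rank computes accuracy on its own shard, pre-divides by world_size, and dist.reduce sums the shares onto rank 0, giving the mean across ranks. The same pattern in isolation, as a minimal sketch assuming an initialized process group; report_mean_metric and its arguments are illustrative:

import torch
import torch.distributed as dist

def report_mean_metric(local_value, rank, world_size, device='cuda'):
    # each rank contributes local_value / world_size; summing the
    # contributions on rank 0 (the default reduce op is SUM) gives the mean
    m = torch.tensor(local_value / world_size, device=device)
    dist.reduce(m, dst=0)
    if rank == 0:
        print('mean metric:', m.item())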
@@ -129,7 +211,8 @@ if __name__ == '__main__':
     graph.create_formats_()     # must be called before mp.spawn().
     split_idx = dataset.get_idx_split()
     num_classes = dataset.num_classes
-    n_procs = 4
+    # use all available GPUs
+    n_procs = torch.cuda.device_count()
 
     # Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs
     # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples.
......
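The __main__ hunk above sizes the job to all visible GPUs and launches one process per device with mp.spawn. A minimal launcher sketch along those lines; the worker body and the TCP init address are assumptions for illustration, not taken from the example:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # each spawned process binds to its own GPU and joins the process
    # group; the example's train() follows the same pattern, with the
    # graph and splits passed through args=(...)
    torch.cuda.set_device(rank)
    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:12345',
                            rank=rank, world_size=world_size)
    print('rank', rank, 'of', world_size, 'initialized')
    dist.destroy_process_group()

if __name__ == '__main__':
    n_procs = torch.cuda.device_count()  # use all available GPUs
    mp.spawn(worker, args=(n_procs,), nprocs=n_procs)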
@@ -55,17 +55,23 @@ SharedMemory::SharedMemory(const std::string &name) {
 SharedMemory::~SharedMemory() {
 #ifndef _WIN32
-  CHECK(munmap(ptr_, size_) != -1) << strerror(errno);
-  close(fd_);
+  if (ptr_ && size_ != 0)
+    CHECK(munmap(ptr_, size_) != -1) << strerror(errno);
+  if (fd_ != -1)
+    close(fd_);
   if (own_) {
     // LOG(INFO) << "remove " << name << " for shared memory";
-    shm_unlink(name.c_str());
-    // The resource has been deleted. We don't need to keep track of it any more.
-    DeleteResource(name);
+    if (name != "") {
+      shm_unlink(name.c_str());
+      // The resource has been deleted. We don't need to keep track of it any more.
+      DeleteResource(name);
+    }
   }
 #else
-  CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
-  CloseHandle(handle_);
+  if (ptr_)
+    CHECK(UnmapViewOfFile(ptr_)) << "Win32 Error: " << GetLastError();
+  if (handle_)
+    CloseHandle(handle_);
   // Windows do not need a separate shm_unlink step.
 #endif  // _WIN32
 }
......