Unverified Commit 7d416086 authored by Quan (Andy) Gan, committed by GitHub

[Bug] Fix multiple issues in distributed multi-GPU GraphSAGE example (#3870)



* fix distributed multi-GPU example device

* try Join

* update version requirement in README

* use model.join

* fix docs
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 16409ff8
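
The common thread in this commit is replacing the old `pad_data` workaround with PyTorch's `DistributedDataParallel.join()` context manager: instead of padding every worker's seed nodes to the same length, the training loops below simply tolerate workers that run out of mini-batches early. The sketch below is not part of the commit; it reproduces the problem and the fix in isolation, and the gloo backend, localhost rendezvous, tiny `Linear` model, and batch counts are all illustrative assumptions.

```python
# Minimal, self-contained sketch (not from the commit): with DistributedDataParallel,
# a worker that receives more mini-batches than its peers blocks forever in gradient
# allreduce once the other workers finish their epoch.  Wrapping the loop in
# model.join() lets the finished workers shadow the missing collectives.
import os
import torch as th
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    model = th.nn.parallel.DistributedDataParallel(th.nn.Linear(4, 2))
    optimizer = th.optim.SGD(model.parameters(), lr=0.1)
    num_batches = 3 if rank == 0 else 5   # uneven number of batches per worker
    with model.join():                    # without this, rank 1 would hang
        for _ in range(num_batches):
            loss = model(th.randn(8, 4)).sum()
            optimizer.zero_grad()
            loss.backward()               # gradient allreduce happens here
            optimizer.step()
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)
```

The same pattern appears in every hunk below: each `for step, ... in enumerate(dataloader)` loop is indented one level under `with model.join():`.
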
@@ -53,14 +53,15 @@ are the same as :ref:`mini-batch training <guide-minibatch>`.
     # training loop
     for epoch in range(args.num_epochs):
-        for step, blocks in enumerate(dataloader):
-            batch_inputs, batch_labels = load_subtensor(g, blocks[0].srcdata[dgl.NID],
-                                                        blocks[-1].dstdata[dgl.NID])
-            batch_pred = model(blocks, batch_inputs)
-            loss = loss_fcn(batch_pred, batch_labels)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+        with model.join():
+            for step, blocks in enumerate(dataloader):
+                batch_inputs, batch_labels = load_subtensor(g, blocks[0].srcdata[dgl.NID],
+                                                            blocks[-1].dstdata[dgl.NID])
+                batch_pred = model(blocks, batch_inputs)
+                loss = loss_fcn(batch_pred, batch_labels)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
 When running the training script in a cluster of machines, DGL provides tools to copy data
 to the cluster's machines and launch the training job on all machines.
@@ -6,6 +6,8 @@ This is an example of training GraphSage in a distributed fashion. Before traini
 sudo pip3 install ogb
 ```
 
+**Requires PyTorch 1.10.0+ to work.**
+
 To train GraphSage, it has five steps:
 
 ### Step 0: Setup a Distributed File System
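
The added requirement reflects the switch to `DistributedDataParallel.join()` in the training script. An optional, hedged guard a script could run up front (not part of the example):

```python
# Optional guard, not part of the example: fail early if the installed PyTorch
# is older than the 1.10.0 the README now requires.
import torch

major, minor = (int(v) for v in torch.__version__.split('+')[0].split('.')[:2])
if (major, minor) < (1, 10):
    raise RuntimeError('This example requires PyTorch >= 1.10.0, found ' + torch.__version__)
```
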
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 import torch.optim as optim
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
+import socket
 
 def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
     """
@@ -155,41 +156,11 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
     model.train()
     return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid])
 
-def pad_data(nids, device):
-    """
-    In distributed traning scenario, we need to make sure that each worker has same number of
-    batches. Otherwise the synchronization(barrier) is called diffirent times, which results in
-    the worker with more batches hangs up.
-
-    This function pads the nids to the same size for all workers, by repeating the head ids till
-    the maximum size among all workers.
-    """
-    import torch.distributed as dist
-    # NCCL backend only supports GPU tensors, thus here we need to allocate it to gpu
-    num_nodes = th.tensor(nids.numel()).to(device)
-    dist.all_reduce(num_nodes, dist.ReduceOp.MAX)
-    max_num_nodes = int(num_nodes)
-    nids_length = nids.shape[0]
-    if max_num_nodes > nids_length:
-        pad_size = max_num_nodes % nids_length
-        repeat_size = max_num_nodes // nids_length
-        new_nids = th.cat([nids for _ in range(repeat_size)] + [nids[:pad_size]], axis=0)
-        print("Pad nids from {} to {}".format(nids_length, max_num_nodes))
-    else:
-        new_nids = nids
-    assert new_nids.shape[0] == max_num_nodes
-    return new_nids
-
 def run(args, device, data):
     # Unpack data
     train_nid, val_nid, test_nid, in_feats, n_classes, g = data
     shuffle = True
-    if args.pad_data:
-        train_nid = pad_data(train_nid, device)
-        # Current pipeline doesn't support duplicate node id within the same batch
-        # Therefore turn off shuffling to avoid potential duplicate node id within the same batch
-        shuffle = False
     # Create sampler
     sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                               dgl.distributed.sample_neighbors, device)

@@ -209,8 +180,7 @@ def run(args, device, data):
     if args.num_gpus == -1:
         model = th.nn.parallel.DistributedDataParallel(model)
     else:
-        dev_id = g.rank() % args.num_gpus
-        model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
+        model = th.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
     loss_fcn = nn.CrossEntropyLoss()
     loss_fcn = loss_fcn.to(device)
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
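
The two hunks above are the device fix from the commit title: the GPU is now chosen once in `main()` (see the last hunk of the training script below) and the resulting `torch.device` is passed straight to `DistributedDataParallel`, whose `device_ids`/`output_device` parameters accept either an integer index or a device object. A hedged sketch of that wiring, with `rank` and `num_gpus` as assumed inputs:

```python
# Sketch of the device wiring the commit converges on.  ``rank`` and ``num_gpus``
# are assumed inputs; in the example they come from g.rank() and args.num_gpus.
import torch as th

def wrap_model(model, rank, num_gpus):
    dev_id = rank % num_gpus                    # one GPU per trainer process
    device = th.device('cuda:' + str(dev_id))
    model = model.to(device)
    # device_ids/output_device accept a torch.device as well as an int index.
    model = th.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], output_device=device)
    return model, device
```
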
@@ -233,43 +203,46 @@ def run(args, device, data):
         # Loop over the dataloader to sample the computation dependency graph as a list of
         # blocks.
         step_time = []
-        for step, blocks in enumerate(dataloader):
-            tic_step = time.time()
-            sample_time += tic_step - start
-
-            # The nodes for input lies at the LHS side of the first block.
-            # The nodes for output lies at the RHS side of the last block.
-            batch_inputs = blocks[0].srcdata['features']
-            batch_labels = blocks[-1].dstdata['labels']
-            batch_labels = batch_labels.long()
-
-            num_seeds += len(blocks[-1].dstdata[dgl.NID])
-            num_inputs += len(blocks[0].srcdata[dgl.NID])
-            blocks = [block.to(device) for block in blocks]
-            batch_labels = batch_labels.to(device)
-            # Compute loss and prediction
-            start = time.time()
-            batch_pred = model(blocks, batch_inputs)
-            loss = loss_fcn(batch_pred, batch_labels)
-            forward_end = time.time()
-            optimizer.zero_grad()
-            loss.backward()
-            compute_end = time.time()
-            forward_time += forward_end - start
-            backward_time += compute_end - forward_end
-
-            optimizer.step()
-            update_time += time.time() - compute_end
-
-            step_t = time.time() - tic_step
-            step_time.append(step_t)
-            iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
-            if step % args.log_every == 0:
-                acc = compute_acc(batch_pred, batch_labels)
-                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
-                print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s'.format(
-                    g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
-            start = time.time()
+
+        with model.join():
+            for step, blocks in enumerate(dataloader):
+                tic_step = time.time()
+                sample_time += tic_step - start
+
+                # The nodes for input lies at the LHS side of the first block.
+                # The nodes for output lies at the RHS side of the last block.
+                batch_inputs = blocks[0].srcdata['features']
+                batch_labels = blocks[-1].dstdata['labels']
+                batch_labels = batch_labels.long()
+
+                num_seeds += len(blocks[-1].dstdata[dgl.NID])
+                num_inputs += len(blocks[0].srcdata[dgl.NID])
+                blocks = [block.to(device) for block in blocks]
+                batch_labels = batch_labels.to(device)
+                # Compute loss and prediction
+                start = time.time()
+                #print(g.rank(), blocks[0].device, model.module.layers[0].fc_neigh.weight.device, dev_id)
+                batch_pred = model(blocks, batch_inputs)
+                loss = loss_fcn(batch_pred, batch_labels)
+                forward_end = time.time()
+                optimizer.zero_grad()
+                loss.backward()
+                compute_end = time.time()
+                forward_time += forward_end - start
+                backward_time += compute_end - forward_end
+
+                optimizer.step()
+                update_time += time.time() - compute_end
+
+                step_t = time.time() - tic_step
+                step_time.append(step_t)
+                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
+                if step % args.log_every == 0:
+                    acc = compute_acc(batch_pred, batch_labels)
+                    gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
+                    print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s'.format(
+                        g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
+                start = time.time()
 
         toc = time.time()
         print('Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
@@ -285,11 +258,14 @@ def run(args, device, data):
         time.time() - start))
 
 def main(args):
+    print(socket.gethostname(), 'Initializing DGL dist')
     dgl.distributed.initialize(args.ip_config)
     if not args.standalone:
+        print(socket.gethostname(), 'Initializing DGL process group')
         th.distributed.init_process_group(backend=args.backend)
+    print(socket.gethostname(), 'Initializing DistGraph')
     g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
-    print('rank:', g.rank())
+    print(socket.gethostname(), 'rank:', g.rank())
 
     pb = g.get_partition_book()
     if 'trainer_id' in g.ndata:

@@ -311,7 +287,8 @@ def main(args):
     if args.num_gpus == -1:
         device = th.device('cpu')
     else:
-        device = th.device('cuda:'+str(args.local_rank))
+        dev_id = g.rank() % args.num_gpus
+        device = th.device('cuda:'+str(dev_id))
     labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
     n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
     print('#labels:', n_classes)
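
A small detail in the hunk above that is easy to miss: the label tensor uses NaN for unlabeled nodes, so they are masked out before counting the number of distinct classes. A tiny stand-alone illustration (not from the commit):

```python
# Unlabeled nodes carry NaN labels; mask them out before counting distinct classes.
import torch as th

labels = th.tensor([0., 2., float('nan'), 1., 2.])
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
print(n_classes)  # prints 3
```
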
@@ -299,23 +299,24 @@ The training loop for distributed training is also exactly the same as the singl
     for epoch in range(10):
         # Loop over the dataloader to sample mini-batches.
         losses = []
-        for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
-            # Load the input features as well as output labels
-            batch_inputs = g.ndata['feat'][input_nodes]
-            batch_labels = g.ndata['labels'][seeds]
-
-            # Compute loss and prediction
-            batch_pred = model(blocks, batch_inputs)
-            loss = loss_fcn(batch_pred, batch_labels)
-            optimizer.zero_grad()
-            loss.backward()
-            losses.append(loss.detach().cpu().numpy())
-            optimizer.step()
+        with model.join():
+            for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
+                # Load the input features as well as output labels
+                batch_inputs = g.ndata['feat'][input_nodes]
+                batch_labels = g.ndata['labels'][seeds]
+
+                # Compute loss and prediction
+                batch_pred = model(blocks, batch_inputs)
+                loss = loss_fcn(batch_pred, batch_labels)
+                optimizer.zero_grad()
+                loss.backward()
+                losses.append(loss.detach().cpu().numpy())
+                optimizer.step()
 
         # validation
         predictions = []
         labels = []
-        with th.no_grad():
+        with th.no_grad(), model.join():
             for step, (input_nodes, seeds, blocks) in enumerate(valid_dataloader):
                 inputs = g.ndata['feat'][input_nodes]
                 labels.append(g.ndata['labels'][seeds].numpy())
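
The validation loop is wrapped in `model.join()` together with `th.no_grad()` for the same reason as the training loop: workers may again see different numbers of batches. The hunk stops at collecting per-batch labels and predictions; a hedged sketch of the step that typically follows (concatenate and score), with toy arrays standing in for the collected lists:

```python
# Hedged sketch of the post-validation scoring; the toy arrays below stand in for
# the ``predictions`` and ``labels`` lists filled in the loop above.
import numpy as np

predictions = [np.array([0, 2, 1]), np.array([1, 1])]
labels = [np.array([0, 2, 2]), np.array([1, 0])]

predictions = np.concatenate(predictions)
labels = np.concatenate(labels)
print('Validation accuracy:', (predictions == labels).mean())  # 0.6 for the toy data
```
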