Unverified Commit 02025989 authored by Rhett Ying, committed by GitHub

[Examples] refine dist train example (#5763)

parent 041f78ba
@@ -12,18 +12,26 @@ import torch.optim as optim
 import tqdm


-def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
+class DistSAGE(nn.Module):
     """
-    Copys features and labels of a set of nodes onto GPU.
+    SAGE model for distributed train and evaluation.
+
+    Parameters
+    ----------
+    in_feats : int
+        Feature dimension.
+    n_hidden : int
+        Hidden layer dimension.
+    n_classes : int
+        Number of classes.
+    n_layers : int
+        Number of layers.
+    activation : callable
+        Activation function.
+    dropout : float
+        Dropout value.
     """
-    batch_inputs = (
-        g.ndata["features"][input_nodes].to(device) if load_feat else None
-    )
-    batch_labels = g.ndata["labels"][seeds].to(device)
-    return batch_inputs, batch_labels
-
-
-class DistSAGE(nn.Module):
+
     def __init__(
         self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
     ):
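For reference, a minimal construction sketch of DistSAGE matching the docstring above; the concrete values are illustrative placeholders (not taken from the example), and `F` is assumed to be `torch.nn.functional` as in typical DGL examples.

import torch.nn.functional as F

# Illustrative values only; the real example derives these from its CLI args.
model = DistSAGE(
    in_feats=100,       # feature dimension
    n_hidden=256,       # hidden layer dimension
    n_classes=47,       # number of classes
    n_layers=2,         # number of layers
    activation=F.relu,  # activation function
    dropout=0.5,        # dropout value
)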
@@ -40,6 +48,16 @@ class DistSAGE(nn.Module):
         self.activation = activation

     def forward(self, blocks, x):
+        """
+        Forward function.
+
+        Parameters
+        ----------
+        blocks : List[DGLBlock]
+            Sampled blocks.
+        x : DistTensor
+            Feature data.
+        """
         h = x
         for i, (layer, block) in enumerate(zip(self.layers, blocks)):
             h = layer(block, h)
@@ -50,41 +68,46 @@ class DistSAGE(nn.Module):
     def inference(self, g, x, batch_size, device):
         """
-        Inference with the GraphSAGE model on full neighbors (i.e. without
-        neighbor sampling).
-        g : the entire graph.
-        x : the input of entire node set.
+        Distributed layer-wise inference with the GraphSAGE model on full
+        neighbors.
+
+        Parameters
+        ----------
+        g : DistGraph
+            Input Graph for inference.
+        x : DistTensor
+            Node feature data of input graph.
-
-        Distributed layer-wise inference.
+
+        Returns
+        -------
+        DistTensor
+            Inference results.
         """
-        # During inference with sampling, multi-layer blocks are very
-        # inefficient because lots of computations in the first few layers
-        # are repeated. Therefore, we compute the representation of all nodes
-        # layer by layer. The nodes on each layer are of course splitted in
-        # batches.
-        # TODO: can we standardize this?
+        # Split nodes to each trainer.
         nodes = dgl.distributed.node_split(
             np.arange(g.num_nodes()),
             g.get_partition_book(),
             force_even=True,
         )
-        y = dgl.distributed.DistTensor(
-            (g.num_nodes(), self.n_hidden),
-            th.float32,
-            "h",
-            persistent=True,
-        )
         for i, layer in enumerate(self.layers):
+            # Create DistTensor to save forward results.
             if i == len(self.layers) - 1:
-                y = dgl.distributed.DistTensor(
-                    (g.num_nodes(), self.n_classes),
-                    th.float32,
-                    "h_last",
-                    persistent=True,
-                )
+                out_dim = self.n_classes
+                name = "h_last"
+            else:
+                out_dim = self.n_hidden
+                name = "h"
+            y = dgl.distributed.DistTensor(
+                (g.num_nodes(), out_dim),
+                th.float32,
+                name,
+                persistent=True,
+            )
-            print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")
+            print(f"|V|={g.num_nodes()}, inference batch size: {batch_size}")
+            # `-1` indicates all inbound edges will be included, namely, full
+            # neighbor sampling.
             sampler = dgl.dataloading.NeighborSampler([-1])
             dataloader = dgl.dataloading.DistNodeDataLoader(
                 g,
@@ -103,17 +126,30 @@ class DistSAGE(nn.Module):
                 if i != len(self.layers) - 1:
                     h = self.activation(h)
                     h = self.dropout(h)
+                # Copy back to CPU as DistTensor requires data reside on CPU.
                 y[output_nodes] = h.cpu()
             x = y
+            # Synchronize trainers.
             g.barrier()
-        return y
+        return x
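For reference, a minimal sketch of how this layer-wise inference might be invoked during evaluation; unwrapping `model.module` assumes the model is wrapped in DistributedDataParallel, which is an assumption here rather than something shown in this hunk.

model.eval()
with th.no_grad():
    # `model.module` assumes a DistributedDataParallel wrapper; drop it if
    # the bare DistSAGE instance is used.
    pred = model.module.inference(g, g.ndata["features"], batch_size, device)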
 def compute_acc(pred, labels):
     """
     Compute the accuracy of prediction given the labels.
+
+    Parameters
+    ----------
+    pred : torch.Tensor
+        Predicted labels.
+    labels : torch.Tensor
+        Ground-truth labels.
+
+    Returns
+    -------
+    float
+        Accuracy.
     """
     labels = labels.long()
     return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
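For reference, a small usage check of compute_acc with dummy tensors (not taken from the example):

pred = th.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = th.tensor([0, 1, 1])
acc = compute_acc(pred, labels)  # 2 of 3 argmax predictions match -> ~0.6667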
@@ -121,13 +157,33 @@ def compute_acc(pred, labels):
 def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
     """
-    Evaluate the model on the validation set specified by ``val_nid``.
-    g : The entire graph.
-    inputs : The features of all the nodes.
-    labels : The labels of all the nodes.
-    val_nid : the node Ids for validation.
-    batch_size : Number of nodes to compute at the same time.
-    device : The GPU device to evaluate on.
+    Evaluate the model on the validation and test set.
+
+    Parameters
+    ----------
+    model : DistSAGE
+        The model to be evaluated.
+    g : DistGraph
+        The entire graph.
+    inputs : DistTensor
+        The feature data of all the nodes.
+    labels : DistTensor
+        The labels of all the nodes.
+    val_nid : torch.Tensor
+        The node IDs for validation.
+    test_nid : torch.Tensor
+        The node IDs for test.
+    batch_size : int
+        Batch size for evaluation.
+    device : torch.Device
+        The target device to evaluate on.
+
+    Returns
+    -------
+    float
+        Validation accuracy.
+    float
+        Test accuracy.
     """
     model.eval()
     with th.no_grad():
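For reference, a hedged sketch of how evaluate might be called from the training loop; the `args.eval_every` and `args.batch_size_eval` names, the epoch gating, and the DDP unwrapping are assumptions about the surrounding example rather than lines shown in this diff.

if epoch % args.eval_every == 0 and epoch != 0:
    val_acc, test_acc = evaluate(
        model.module,  # assumes a DistributedDataParallel wrapper
        g,
        g.ndata["features"],
        g.ndata["labels"],
        val_nid,
        test_nid,
        args.batch_size_eval,  # assumed argument name
        device,
    )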
@@ -139,6 +195,19 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
 def run(args, device, data):
+    """
+    Train and evaluate DistSAGE.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Arguments for train and evaluate.
+    device : torch.Device
+        Target device for train and evaluate.
+    data : Packed Data
+        Packed data includes train/val/test IDs, feature dimension,
+        number of classes, graph.
+    """
     train_nid, val_nid, test_nid, in_feats, n_classes, g = data
     sampler = dgl.dataloading.NeighborSampler(
         [int(fanout) for fanout in args.fan_out.split(",")]
@@ -178,6 +247,7 @@ def run(args, device, data):
     for _ in range(args.num_epochs):
         epoch += 1
         tic = time.time()
+        # Various time statistics.
         sample_time = 0
         forward_time = 0
         backward_time = 0
@@ -185,18 +255,15 @@ def run(args, device, data):
         num_seeds = 0
         num_inputs = 0
         start = time.time()
-        # Loop over the dataloader to sample the computation dependency graph
-        # as a list of blocks.
         step_time = []
         with model.join():
             for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
                 tic_step = time.time()
                 sample_time += tic_step - start
-                batch_inputs, batch_labels = load_subtensor(
-                    g, seeds, input_nodes, "cpu"
-                )
-                batch_labels = batch_labels.long()
+                # Slice feature and label.
+                batch_inputs = g.ndata["features"][input_nodes]
+                batch_labels = g.ndata["labels"][seeds].long()
                 num_seeds += len(blocks[-1].dstdata[dgl.NID])
                 num_inputs += len(blocks[0].srcdata[dgl.NID])
                 # Move to target device.
@@ -227,36 +294,23 @@ def run(args, device, data):
                         if th.cuda.is_available()
                         else 0
                     )
+                    sample_speed = np.mean(iter_tput[-args.log_every :])
+                    mean_step_time = np.mean(step_time[-args.log_every :])
                     print(
-                        "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
-                        "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
-                        "{:.1f} MB | time {:.3f} s".format(
-                            g.rank(),
-                            epoch,
-                            step,
-                            loss.item(),
-                            acc.item(),
-                            np.mean(iter_tput[3:]),
-                            gpu_mem_alloc,
-                            np.mean(step_time[-args.log_every :]),
-                        )
+                        f"Part {g.rank()} | Epoch {epoch:05d} | Step {step:05d}"
+                        f" | Loss {loss.item():.4f} | Train Acc {acc.item():.4f}"
+                        f" | Speed (samples/sec) {sample_speed:.4f}"
+                        f" | GPU {gpu_mem_alloc:.1f} MB | "
+                        f"Mean step time {mean_step_time:.3f} s"
                     )
                 start = time.time()
         toc = time.time()
         print(
-            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
-            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
-            "#inputs: {}".format(
-                g.rank(),
-                toc - tic,
-                sample_time,
-                forward_time,
-                backward_time,
-                update_time,
-                num_seeds,
-                num_inputs,
-            )
+            f"Part {g.rank()}, Epoch Time(s): {toc - tic:.4f}, "
+            f"sample+data_copy: {sample_time:.4f}, forward: {forward_time:.4f},"
+            f" backward: {backward_time:.4f}, update: {update_time:.4f}, "
+            f"#seeds: {num_seeds}, #inputs: {num_inputs}"
         )
         epoch_time.append(toc - tic)
@@ -273,23 +327,27 @@ def run(args, device, data):
                 device,
             )
             print(
-                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
-                    g.rank(), val_acc, test_acc, time.time() - start
-                )
+                f"Part {g.rank()}, Val Acc {val_acc:.4f}, "
+                f"Test Acc {test_acc:.4f}, time: {time.time() - start:.4f}"
             )
     return np.mean(epoch_time[-int(args.num_epochs * 0.8) :]), test_acc


 def main(args):
-    print(socket.gethostname(), "Initializing DistDGL.")
+    """
+    Main function.
+    """
+    host_name = socket.gethostname()
+    print(f"{host_name}: Initializing DistDGL.")
     dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
-    print(socket.gethostname(), "Initializing PyTorch process group.")
+    print(f"{host_name}: Initializing PyTorch process group.")
     th.distributed.init_process_group(backend=args.backend)
-    print(socket.gethostname(), "Initializing DistGraph.")
+    print(f"{host_name}: Initializing DistGraph.")
     g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
-    print(socket.gethostname(), "rank:", g.rank())
+    print(f"Rank of {host_name}: {g.rank()}")
+
+    # Split train/val/test IDs for each trainer.
     pb = g.get_partition_book()
     if "trainer_id" in g.ndata:
         train_nid = dgl.distributed.node_split(
@@ -321,17 +379,13 @@ def main(args):
             g.ndata["test_mask"], pb, force_even=True
         )
     local_nid = pb.partid2nids(pb.partid).detach().numpy()
+    num_train_local = len(np.intersect1d(train_nid.numpy(), local_nid))
+    num_val_local = len(np.intersect1d(val_nid.numpy(), local_nid))
+    num_test_local = len(np.intersect1d(test_nid.numpy(), local_nid))
     print(
-        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
-        "(local: {})".format(
-            g.rank(),
-            len(train_nid),
-            len(np.intersect1d(train_nid.numpy(), local_nid)),
-            len(val_nid),
-            len(np.intersect1d(val_nid.numpy(), local_nid)),
-            len(test_nid),
-            len(np.intersect1d(test_nid.numpy(), local_nid)),
-        )
+        f"part {g.rank()}, train: {len(train_nid)} (local: {num_train_local}), "
+        f"val: {len(val_nid)} (local: {num_val_local}), "
+        f"test: {len(test_nid)} (local: {num_test_local})"
     )
     del local_nid
     if args.num_gpus == 0:
...