import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
import traceback


#### Neighbor sampler

class NeighborSampler(object):
    """
    Builds the list of message-passing blocks for a batch of seed nodes.

    g : the graph neighbors are sampled from.
    fanouts : iterable of ints; one block is produced per entry, and each
        entry is the number of neighbors sampled per node for that block.
    """
    def __init__(self, g, fanouts):
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        """
        Sample one block per fanout, starting from ``seeds``.

        seeds : sequence of node IDs (used as the DataLoader ``collate_fn``
            input in this file).

        Returns the list of blocks. Each newly sampled block is inserted at the
        front of the list, so the block sampled directly from ``seeds`` is last.
        """
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks


class SAGE(nn.Module):
    """
    Stack of ``SAGEConv`` layers with 'mean' aggregation.

    in_feats : input feature size.
    n_hidden : hidden feature size.
    n_classes : output size of the last layer.
    n_layers : requested number of layers. NOTE: the constructor always adds a
        first (in_feats -> n_hidden) and a last (n_hidden -> n_classes) layer,
        so values below 2 still yield two layers.
    activation : callable applied after every layer except the last.
    dropout : dropout probability applied after every layer except the last.
    """
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for _ in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """
        Run the layers over the sampled ``blocks``, pairing layer i with block i.

        blocks : list of blocks as produced by ``NeighborSampler.sample_blocks``.
        x : input features of the source nodes of the first block.

        Returns the output of the last layer.
        """
        h = x
        last = len(self.layers) - 1
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            # We need to first copy the representation of nodes on the RHS from the
            # appropriate nodes on the LHS.
            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
            # would be (num_nodes_RHS, D)
            h_dst = h[:block.number_of_dst_nodes()]
            # Then we compute the updated representation on the RHS.
            # The shape of h now becomes (num_nodes_RHS, D)
            h = layer(block, (h, h_dst))
            if l != last:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of destination nodes processed per step.
        device : device the per-batch computation is moved to.

        The inference code is written in a fashion that it could handle any number of nodes and
        layers.

        Returns the last layer's output for every node, as a tensor created with
        ``th.zeros`` and filled from ``h.cpu()``.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = th.arange(g.number_of_nodes())
        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.number_of_nodes(), self.n_hidden if l != last else self.n_classes)

            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]

                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[start:end] = h.cpu()

            # The output of this layer is the input of the next one.
            x = y
        return y


#### Miscellaneous functions

# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow.  If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
def thread_wrapped_func(func):
    """
    Wraps a process entry point to make it work with OpenMP.

    The wrapped function runs ``func`` in a new thread and blocks until that
    thread reports back through a queue. If ``func`` raised an ``Exception``,
    an exception of the same class is raised in the calling thread with the
    formatted traceback as its message; when that class cannot be constructed
    from a single string, ``RuntimeError`` is raised instead.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()

        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))

        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result

        # Some exception classes do not accept a single message argument;
        # constructing them would raise an unrelated TypeError and hide the
        # real traceback, so fall back to RuntimeError in that case.
        try:
            wrapped = exception.__class__(trace)
        except Exception:
            wrapped = RuntimeError(trace)
        raise wrapped from exception
    return decorated_function


def prepare_mp(g):
    """
    Explicitly materialize the CSR, CSC and COO representation of the given graph
    so that they could be shared via copy-on-write to sampler workers and GPU
    trainers.

    This is a workaround before full shared memory support on heterogeneous graphs.
    """
    g.in_degree(0)
    g.out_degree(0)
    g.find_edges([0])


def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.

    pred : per-class scores; the predicted class is the argmax over dim 1.
    labels : ground-truth class indices, one per row of ``pred``.
    """
    correct = (th.argmax(pred, dim=1) == labels).float().sum()
    return correct / len(pred)


def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_mask``.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.

    The model is switched to eval mode for the inference pass and put back
    into training mode before returning.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_mask], labels[val_mask])


def load_subtensor(g, labels, seeds, input_nodes, dev_id):
    """
    Copies features and labels of a set of nodes onto GPU.

    Returns ``(batch_inputs, batch_labels)``: the ``'features'`` rows of
    ``input_nodes`` and the labels of ``seeds``, both moved to ``dev_id``.
    """
    batch_inputs = g.ndata['features'][input_nodes].to(dev_id)
    batch_labels = labels[seeds].to(dev_id)
    return batch_inputs, batch_labels


#### Entry point

def run(proc_id, n_gpus, args, devices, data):
    """
    Train GraphSAGE on one GPU; one call per trainer process.

    proc_id : rank of this trainer; rank 0 does all logging and evaluation.
    n_gpus : number of trainers; distributed training is set up when > 1.
    args : parsed command-line arguments.
    devices : list of GPU device IDs, indexed by ``proc_id``.
    data : tuple ``(train_mask, val_mask, in_feats, labels, n_classes, g)``.
    """
    # Start up distributed training, if enabled.
    dev_id = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)
    th.cuda.set_device(dev_id)

    # Unpack data
    train_mask, val_mask, in_feats, labels, n_classes, g = data
    train_nid = th.LongTensor(np.nonzero(train_mask)[0])
    val_mask = th.BoolTensor(val_mask)

    # Split train_nid
    train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]

    # Create sampler
    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid.numpy(),
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
    model = model.to(dev_id)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    n_warmup_epochs = 5   # the first epochs are excluded from the average epoch time
    timed_epochs = 0      # number of epochs accumulated into ``avg``
    avg = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        for step, blocks in enumerate(dataloader):
            if proc_id == 0:
                tic_step = time.time()

            # The nodes for input lies at the LHS side of the first block.
            # The nodes for output lies at the RHS side of the last block.
            input_nodes = blocks[0].srcdata[dgl.NID]
            seeds = blocks[-1].dstdata[dgl.NID]

            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, dev_id)

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            # With n_gpus > 1 the model is wrapped in DistributedDataParallel,
            # which averages gradients across trainers during backward(), so
            # no manual all-reduce is needed here.
            loss.backward()
            optimizer.step()

            if proc_id == 0:
                iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
            if step % args.log_every == 0 and proc_id == 0:
                acc = compute_acc(batch_pred, batch_labels)
                # The first 3 measurements are skipped; report nan until there is
                # something left to average (avoids numpy's empty-mean warning).
                speed = np.mean(iter_tput[3:]) if len(iter_tput) > 3 else float('nan')
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                    epoch, step, loss.item(), acc.item(), speed,
                    th.cuda.max_memory_allocated() / (1024 * 1024)))

        if n_gpus > 1:
            th.distributed.barrier()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
            if epoch >= n_warmup_epochs:
                avg += toc - tic
                timed_epochs += 1
            if epoch % args.eval_every == 0 and epoch != 0:
                # Evaluate the bare module, not the DistributedDataParallel wrapper.
                eval_model = model if n_gpus == 1 else model.module
                eval_acc = evaluate(eval_model, g, g.ndata['features'], labels, val_mask,
                                    args.batch_size, devices[0])
                print('Eval Acc {:.4f}'.format(eval_acc))

    if n_gpus > 1:
        th.distributed.barrier()
    if proc_id == 0:
        if timed_epochs > 0:
            print('Avg epoch time: {}'.format(avg / timed_epochs))
        else:
            # Fewer epochs than the warm-up window: nothing was timed.
            print('Avg epoch time: n/a (no epochs after the first {})'.format(n_warmup_epochs))


def main():
    """Parse the command line, load Reddit, and launch one trainer per GPU."""
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument('--gpu', type=str, default='0',
                           help="Comma separated list of GPU device IDs.")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    # load reddit data
    dataset = RedditDataset(self_loop=True)
    train_mask = dataset.train_mask
    val_mask = dataset.val_mask
    features = th.Tensor(dataset.features)
    in_feats = features.shape[1]
    labels = th.LongTensor(dataset.labels)
    n_classes = dataset.num_labels
    # Construct graph
    g = dgl.graph(dataset.graph.all_edges())
    g.ndata['features'] = features
    prepare_mp(g)
    # Pack data
    data = train_mask, val_mask, in_feats, labels, n_classes, g

    if n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


if __name__ == '__main__':
    main()