Unverified Commit 17d604b5 authored by nv-dlasalle, committed by GitHub

[Feature] Allow using NCCL for communication in dgl.NodeEmbedding and dgl.SparseOptimizer (#2824)



* Split from NCCL PR

* Fix typo in comment

* Expand documentation for sparse_all_to_all_push

* Restore previous behavior in example

* Re-work optimizer to use NCCL based on gradient location

* Allow for running with embedding on CPU but using NCCL for gradient exchange

* Optimize single partition case

* Fix pylint errors

* Add missing include

* fix gradient indexing

* Fix line continuation

* Migrate 'first_step'

* Skip tests without enough GPUs to run NCCL

* Improve empty tensor handling for pytorch 1.5

* Fix indentation

* Allow multiple NCCL communicators to coexist

* Improve handling of empty message

* Update python/dgl/nn/pytorch/sparse_emb.py
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>

* Update python/dgl/nn/pytorch/sparse_emb.py
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>

* Keep empty tensor dimensionless

* th.empty -> th.tensor

* Preserve shape for empty non-zero dimension tensors

* Use shared state, when embedding is shared

* Add support for gathering an embedding

* Fix typo

* Fix more typos

* Fix backend call

* Use NodeDataLoader to take advantage of ddp

* Update training script to share memory

* Only squeeze last dimension

* Better handle empty message

* Keep embedding on the target GPU device if dgl_sparse is false in RGCN example

* Fix typo in comment

* Add asserts

* Improve documentation in example
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 9497a9be
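For orientation, the sketch below shows roughly how the feature added by this commit might be used from a per-process training script: a dgl.nn.NodeEmbedding stored on the GPU via the new device argument, with dgl.optim.SparseAdam exchanging its gradients over NCCL. The mini-batch source, loss, and hyper-parameters are placeholders, not part of this diff.

    import torch as th
    import dgl

    def train(rank, world_size, num_nodes, embed_dim, batches):
        # one process per GPU; the NCCL process group is shared by DDP and
        # by the sparse optimizer's gradient exchange
        th.cuda.set_device(rank)
        th.distributed.init_process_group(
            'nccl', init_method='tcp://127.0.0.1:12345',
            world_size=world_size, rank=rank)
        device = th.device(rank)
        # store the embedding table directly on the GPU (new device argument);
        # each process then holds only its partition of the rows
        emb = dgl.nn.NodeEmbedding(
            num_nodes, embed_dim, name='node_emb',
            init_func=lambda t: th.nn.init.uniform_(t, -1.0, 1.0),
            device=device)
        optimizer = dgl.optim.SparseAdam([emb], lr=0.01)
        for node_ids in batches:               # placeholder mini-batch source
            feats = emb(node_ids.to(device), device)
            loss = feats.pow(2).mean()         # placeholder loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()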
...@@ -123,54 +123,7 @@ def gen_norm(g): ...@@ -123,54 +123,7 @@ def gen_norm(g):
norm = norm.unsqueeze(1) norm = norm.unsqueeze(1)
g.edata['norm'] = norm g.edata['norm'] = norm
class NeighborSampler: def evaluate(model, embed_layer, eval_loader, node_feats, inv_target):
"""Neighbor sampler
Parameters
----------
g : DGLHeterograph
Full graph
target_idx : tensor
The target training node IDs in g
fanouts : list of int
Fanout of each hop starting from the seed nodes. If a fanout is None,
sample full neighbors.
"""
def __init__(self, g, target_idx, fanouts):
self.g = g
self.target_idx = target_idx
self.fanouts = fanouts
"""Do neighbor sample
Parameters
----------
seeds :
Seed nodes
Returns
-------
tensor
Seed nodes, also known as target nodes
blocks
Sampled subgraphs
"""
def sample_blocks(self, seeds):
blocks = []
etypes = []
norms = []
ntypes = []
seeds = th.tensor(seeds).long()
cur = self.target_idx[seeds]
for fanout in self.fanouts:
if fanout is None or fanout == -1:
frontier = dgl.in_subgraph(self.g, cur)
else:
frontier = dgl.sampling.sample_neighbors(self.g, cur, fanout)
block = dgl.to_block(frontier, cur)
gen_norm(block)
cur = block.srcdata[dgl.NID]
blocks.insert(0, block)
return seeds, blocks
def evaluate(model, embed_layer, eval_loader, node_feats):
model.eval() model.eval()
embed_layer.eval() embed_layer.eval()
eval_logits = [] eval_logits = []
...@@ -179,7 +132,11 @@ def evaluate(model, embed_layer, eval_loader, node_feats): ...@@ -179,7 +132,11 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
with th.no_grad(): with th.no_grad():
th.cuda.empty_cache() th.cuda.empty_cache()
for sample_data in tqdm.tqdm(eval_loader): for sample_data in tqdm.tqdm(eval_loader):
seeds, blocks = sample_data inputs, seeds, blocks = sample_data
seeds = inv_target[seeds]
for block in blocks:
gen_norm(block)
feats = embed_layer(blocks[0].srcdata[dgl.NID], feats = embed_layer(blocks[0].srcdata[dgl.NID],
blocks[0].srcdata['ntype'], blocks[0].srcdata['ntype'],
...@@ -197,7 +154,7 @@ def evaluate(model, embed_layer, eval_loader, node_feats): ...@@ -197,7 +154,7 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
dev_id = devices[proc_id] if devices[proc_id] != 'cpu' else -1 dev_id = devices[proc_id] if devices[proc_id] != 'cpu' else -1
g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \ g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \
train_idx, val_idx, test_idx, labels = dataset inv_target, train_idx, val_idx, test_idx, labels = dataset
if split is not None: if split is not None:
train_seed, val_seed, test_seed = split train_seed, val_seed, test_seed = split
train_idx = train_idx[train_seed] train_idx = train_idx[train_seed]
...@@ -206,48 +163,63 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): ...@@ -206,48 +163,63 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
fanouts = [int(fanout) for fanout in args.fanout.split(',')] fanouts = [int(fanout) for fanout in args.fanout.split(',')]
node_tids = g.ndata[dgl.NTYPE] node_tids = g.ndata[dgl.NTYPE]
sampler = NeighborSampler(g, target_idx, fanouts)
loader = DataLoader(dataset=train_idx.numpy(), world_size = n_gpus
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
backend = 'nccl'
# using sparse embedding or using mix_cpu_gpu model (embedding model can not be stored in GPU)
if dev_id < 0 or args.dgl_sparse is False:
backend = 'gloo'
print("backend using {}".format(backend))
th.distributed.init_process_group(backend=backend,
init_method=dist_init_method,
world_size=world_size,
rank=proc_id)
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
loader = dgl.dataloading.NodeDataLoader(
g,
target_idx[train_idx],
sampler,
use_ddp=n_gpus > 1,
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=True,
drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers)
# validation sampler # validation sampler
val_sampler = NeighborSampler(g, target_idx, fanouts) val_loader = dgl.dataloading.NodeDataLoader(
val_loader = DataLoader(dataset=val_idx.numpy(), g,
target_idx[val_idx],
sampler,
use_ddp=n_gpus > 1,
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=val_sampler.sample_blocks,
shuffle=False, shuffle=False,
drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers)
# test sampler # test sampler
test_sampler = NeighborSampler(g, target_idx, [None] * args.n_layers) test_sampler = dgl.dataloading.MultiLayerNeighborSampler([None] * args.n_layers)
test_loader = DataLoader(dataset=test_idx.numpy(), test_loader = dgl.dataloading.NodeDataLoader(
g,
target_idx[test_idx],
test_sampler,
use_ddp=n_gpus > 1,
batch_size=args.eval_batch_size, batch_size=args.eval_batch_size,
collate_fn=test_sampler.sample_blocks,
shuffle=False, shuffle=False,
drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers)
world_size = n_gpus
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
backend = 'nccl'
# using sparse embedding or usig mix_cpu_gpu model (embedding model can not be stored in GPU)
if args.dgl_sparse is False:
backend = 'gloo'
print("backend using {}".format(backend))
th.distributed.init_process_group(backend=backend,
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
# node features # node features
# None for one-hot feature, if not none, it should be the feature tensor. # None for one-hot feature, if not none, it should be the feature tensor.
# #
embed_layer = RelGraphEmbedLayer(dev_id, embed_layer = RelGraphEmbedLayer(dev_id if args.embedding_gpu or not args.dgl_sparse else -1,
dev_id,
g.number_of_nodes(), g.number_of_nodes(),
node_tids, node_tids,
num_of_ntype, num_of_ntype,
...@@ -279,6 +251,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): ...@@ -279,6 +251,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
if n_gpus > 1: if n_gpus > 1:
labels = labels.to(dev_id) labels = labels.to(dev_id)
if dev_id >= 0:
model.cuda(dev_id) model.cuda(dev_id)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
if args.dgl_sparse: if args.dgl_sparse:
...@@ -331,7 +304,14 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): ...@@ -331,7 +304,14 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
embed_layer.train() embed_layer.train()
for i, sample_data in enumerate(loader): for i, sample_data in enumerate(loader):
seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
# map the seed nodes back to their type-specific ids, so that they
# can be used to look up their respective labels
seeds = inv_target[seeds]
for block in blocks:
gen_norm(block)
t0 = time.time() t0 = time.time()
feats = embed_layer(blocks[0].srcdata[dgl.NID], feats = embed_layer(blocks[0].srcdata[dgl.NID],
blocks[0].srcdata['ntype'], blocks[0].srcdata['ntype'],
...@@ -353,7 +333,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): ...@@ -353,7 +333,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
forward_time.append(t1 - t0) forward_time.append(t1 - t0)
backward_time.append(t2 - t1) backward_time.append(t2 - t1)
train_acc = th.sum(logits.argmax(dim=1) == labels[seeds]).item() / len(seeds) train_acc = th.sum(logits.argmax(dim=1) == labels[seeds]).item() / len(seeds)
if i % 100 and proc_id == 0: if i % 100 == 0 and proc_id == 0:
print("Train Accuracy: {:.4f} | Train Loss: {:.4f}". print("Train Accuracy: {:.4f} | Train Loss: {:.4f}".
format(train_acc, loss.item())) format(train_acc, loss.item()))
gc.collect() gc.collect()
...@@ -379,7 +359,8 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): ...@@ -379,7 +359,8 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
vstart = time.time() vstart = time.time()
if (queue is not None) or (proc_id == 0): if (queue is not None) or (proc_id == 0):
val_logits, val_seeds = evaluate(model, embed_layer, val_loader, node_feats) val_logits, val_seeds = evaluate(model, embed_layer, val_loader,
node_feats, inv_target)
if queue is not None: if queue is not None:
queue.put((val_logits, val_seeds)) queue.put((val_logits, val_seeds))
...@@ -407,7 +388,9 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None): ...@@ -407,7 +388,9 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, split, queue=None):
if epoch > 0 and do_test: if epoch > 0 and do_test:
tstart = time.time() tstart = time.time()
if (queue is not None) or (proc_id == 0): if (queue is not None) or (proc_id == 0):
test_logits, test_seeds = evaluate(model, embed_layer, test_loader, node_feats) test_logits, test_seeds = evaluate(model, embed_layer,
test_loader, node_feats,
inv_target)
if queue is not None: if queue is not None:
queue.put((test_logits, test_seeds)) queue.put((test_logits, test_seeds))
...@@ -532,6 +515,17 @@ def main(args, devices): ...@@ -532,6 +515,17 @@ def main(args, devices):
train_idx.share_memory_() train_idx.share_memory_()
val_idx.share_memory_() val_idx.share_memory_()
test_idx.share_memory_() test_idx.share_memory_()
# This is a graph with multiple node types, so we want a way to map
# our target nodes from their global node numbering back to their
# numbering within their type. This is used when taking the nodes in a
# mini-batch and looking up their type-specific labels.
inv_target = th.empty(node_ids.shape,
dtype=node_ids.dtype)
inv_target.share_memory_()
inv_target[target_idx] = th.arange(0, target_idx.shape[0],
dtype=inv_target.dtype)
# Create csr/coo/csc formats before launching training processes with multi-gpu. # Create csr/coo/csc formats before launching training processes with multi-gpu.
# This avoids creating certain formats in each sub-process, which saves momory and CPU. # This avoids creating certain formats in each sub-process, which saves momory and CPU.
g.create_formats_() g.create_formats_()
...@@ -542,12 +536,12 @@ def main(args, devices): ...@@ -542,12 +536,12 @@ def main(args, devices):
if devices[0] == -1: if devices[0] == -1:
run(0, 0, n_cpus, args, ['cpu'], run(0, 0, n_cpus, args, ['cpu'],
(g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, (g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels), None, None) inv_target, train_idx, val_idx, test_idx, labels), None, None)
# gpu # gpu
elif n_gpus == 1: elif n_gpus == 1:
run(0, n_gpus, n_cpus, args, devices, run(0, n_gpus, n_cpus, args, devices,
(g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, (g, node_feats, num_of_ntype, num_classes, num_rels, target_idx,
train_idx, val_idx, test_idx, labels), None, None) inv_target, train_idx, val_idx, test_idx, labels), None, None)
# multi gpu # multi gpu
else: else:
queue = mp.Queue(n_gpus) queue = mp.Queue(n_gpus)
...@@ -577,8 +571,10 @@ def main(args, devices): ...@@ -577,8 +571,10 @@ def main(args, devices):
if (proc_id + 1) * tstseeds_per_proc < num_test_seeds \ if (proc_id + 1) * tstseeds_per_proc < num_test_seeds \
else num_test_seeds] else num_test_seeds]
p = mp.Process(target=run, args=(proc_id, n_gpus, n_cpus // n_gpus, args, devices, p = mp.Process(target=run, args=(proc_id, n_gpus, n_cpus // n_gpus, args, devices,
(g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, (g, node_feats, num_of_ntype,
train_idx, val_idx, test_idx, labels), num_classes, num_rels, target_idx,
inv_target, train_idx, val_idx,
test_idx, labels),
(proc_train_seeds, proc_valid_seeds, proc_test_seeds), (proc_train_seeds, proc_valid_seeds, proc_test_seeds),
queue)) queue))
p.start() p.start()
...@@ -626,6 +622,8 @@ def config(): ...@@ -626,6 +622,8 @@ def config():
help="Whether use low mem RelGraphCov") help="Whether use low mem RelGraphCov")
parser.add_argument("--dgl-sparse", default=False, action='store_true', parser.add_argument("--dgl-sparse", default=False, action='store_true',
help='Use sparse embedding for node embeddings.') help='Use sparse embedding for node embeddings.')
parser.add_argument("--embedding-gpu", default=False, action='store_true',
help='Store the node embeddings on the GPU.')
parser.add_argument('--node-feats', default=False, action='store_true', parser.add_argument('--node-feats', default=False, action='store_true',
help='Whether use node features') help='Whether use node features')
parser.add_argument('--layer-norm', default=False, action='store_true', parser.add_argument('--layer-norm', default=False, action='store_true',
......
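The backend selection in the example above amounts to a one-line rule; restating it as a hedged helper (the function name is illustrative and the flag names mirror the script's arguments; the gloo fallback reflects that PyTorch's NCCL backend does not handle the sparse gradients produced by torch.nn.Embedding under DDP):

    def choose_ddp_backend(dev_id, dgl_sparse):
        # train on CPU, or let torch handle sparse embedding gradients through
        # DDP (dgl_sparse disabled): both cases need gloo; otherwise NCCL can
        # carry every exchanged tensor on the GPU
        return 'gloo' if dev_id < 0 or not dgl_sparse else 'nccl'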
...@@ -58,8 +58,10 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -58,8 +58,10 @@ class RelGraphEmbedLayer(nn.Module):
r"""Embedding layer for featureless heterograph. r"""Embedding layer for featureless heterograph.
Parameters Parameters
---------- ----------
dev_id : int storage_dev_id : int
Device to run the layer. The device to store the weights of the layer.
out_dev_id : int
Device to return the output embeddings on.
num_nodes : int num_nodes : int
Number of nodes. Number of nodes.
node_tides : tensor node_tides : tensor
...@@ -75,7 +77,8 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -75,7 +77,8 @@ class RelGraphEmbedLayer(nn.Module):
If true, use dgl.nn.NodeEmbedding otherwise use torch.nn.Embedding If true, use dgl.nn.NodeEmbedding otherwise use torch.nn.Embedding
""" """
def __init__(self, def __init__(self,
dev_id, storage_dev_id,
out_dev_id,
num_nodes, num_nodes,
node_tids, node_tids,
num_of_ntype, num_of_ntype,
...@@ -83,7 +86,9 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -83,7 +86,9 @@ class RelGraphEmbedLayer(nn.Module):
embed_size, embed_size,
dgl_sparse=False): dgl_sparse=False):
super(RelGraphEmbedLayer, self).__init__() super(RelGraphEmbedLayer, self).__init__()
self.dev_id = th.device(dev_id if dev_id >= 0 else 'cpu') self.storage_dev_id = th.device( \
storage_dev_id if storage_dev_id >= 0 else 'cpu')
self.out_dev_id = th.device(out_dev_id if out_dev_id >= 0 else 'cpu')
self.embed_size = embed_size self.embed_size = embed_size
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.dgl_sparse = dgl_sparse self.dgl_sparse = dgl_sparse
...@@ -97,14 +102,16 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -97,14 +102,16 @@ class RelGraphEmbedLayer(nn.Module):
if isinstance(input_size[ntype], int): if isinstance(input_size[ntype], int):
if dgl_sparse: if dgl_sparse:
self.node_embeds[str(ntype)] = dgl.nn.NodeEmbedding(input_size[ntype], embed_size, name=str(ntype), self.node_embeds[str(ntype)] = dgl.nn.NodeEmbedding(input_size[ntype], embed_size, name=str(ntype),
init_func=initializer) init_func=initializer, device=self.storage_dev_id)
else: else:
sparse_emb = th.nn.Embedding(input_size[ntype], embed_size, sparse=True) sparse_emb = th.nn.Embedding(input_size[ntype], embed_size, sparse=True)
sparse_emb.cuda(self.storage_dev_id)
nn.init.uniform_(sparse_emb.weight, -1.0, 1.0) nn.init.uniform_(sparse_emb.weight, -1.0, 1.0)
self.node_embeds[str(ntype)] = sparse_emb self.node_embeds[str(ntype)] = sparse_emb
else: else:
input_emb_size = input_size[ntype].shape[1] input_emb_size = input_size[ntype].shape[1]
embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size)) embed = nn.Parameter(th.empty([input_emb_size, self.embed_size],
device=self.storage_dev_id))
nn.init.xavier_uniform_(embed) nn.init.xavier_uniform_(embed)
self.embeds[str(ntype)] = embed self.embeds[str(ntype)] = embed
...@@ -136,16 +143,24 @@ class RelGraphEmbedLayer(nn.Module): ...@@ -136,16 +143,24 @@ class RelGraphEmbedLayer(nn.Module):
tensor tensor
embeddings as the input of the next layer embeddings as the input of the next layer
""" """
tsd_ids = node_ids.to(self.dev_id) embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.out_dev_id)
embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.dev_id)
# transfer input to the correct device
type_ids = type_ids.to(self.storage_dev_id)
node_tids = node_tids.to(self.storage_dev_id)
# build locs first
locs = [None for i in range(self.num_of_ntype)]
for ntype in range(self.num_of_ntype):
locs[ntype] = (node_tids == ntype).nonzero().squeeze(-1)
for ntype in range(self.num_of_ntype): for ntype in range(self.num_of_ntype):
loc = node_tids == ntype loc = locs[ntype]
if isinstance(features[ntype], int): if isinstance(features[ntype], int):
if self.dgl_sparse: if self.dgl_sparse:
embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc], self.dev_id) embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc], self.out_dev_id)
else: else:
embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc]).to(self.dev_id) embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc]).to(self.out_dev_id)
else: else:
embeds[loc] = features[ntype][type_ids[loc]].to(self.dev_id) @ self.embeds[str(ntype)].to(self.dev_id) embeds[loc] = features[ntype][type_ids[loc]].to(self.out_dev_id) @ self.embeds[str(ntype)].to(self.out_dev_id)
return embeds return embeds
...@@ -338,6 +338,13 @@ def zerocopy_from_dgl_ndarray(data): ...@@ -338,6 +338,13 @@ def zerocopy_from_dgl_ndarray(data):
# The issue will be fixed in v1.6 and later. # The issue will be fixed in v1.6 and later.
return th.tensor([], dtype=getattr(th, data.dtype), return th.tensor([], dtype=getattr(th, data.dtype),
device=to_backend_ctx(data.ctx)) device=to_backend_ctx(data.ctx))
elif len(data.shape) == 0 or builtins.min(data.shape) == 0:
# Workaround the same issue as above, but preserve the shape of the
# empty tensor. This is needed by the sparse optimizer when one of
# processors may receive no gradients to update, but we want to keep
# the dimension of the embedding.
return th.empty(data.shape, dtype=getattr(th, data.dtype),
device=to_backend_ctx(data.ctx))
else: else:
return dlpack.from_dlpack(data.to_dlpack()) return dlpack.from_dlpack(data.to_dlpack())
......
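A quick illustration (assumed shapes only) of why the new branch preserves the shape of empty tensors: a rank that receives no gradients still needs a (0, embedding_dim) tensor so that concatenation and the optimizer update see a consistent feature dimension.

    import torch as th

    empty_grad = th.empty((0, 16))              # zero rows, embedding_dim kept
    assert empty_grad.shape == (0, 16)
    combined = th.cat([empty_grad, th.randn(4, 16)], dim=0)
    assert combined.shape == (4, 16)            # downstream code keeps working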
...@@ -88,7 +88,7 @@ class Communicator(object): ...@@ -88,7 +88,7 @@ class Communicator(object):
The 1D set of indices to send to other processors. The 1D set of indices to send to other processors.
value : tensor value : tensor
The multi-dimension set of values to send to other processors. The multi-dimension set of values to send to other processors.
The 0th dimension must match that of `idx`. The first dimension must match that of `idx`.
partition : NDArrayPartition partition : NDArrayPartition
The object containing information for assigning indices to The object containing information for assigning indices to
processors. processors.
...@@ -137,7 +137,7 @@ class Communicator(object): ...@@ -137,7 +137,7 @@ class Communicator(object):
def sparse_all_to_all_pull(self, req_idx, value, partition): def sparse_all_to_all_pull(self, req_idx, value, partition):
""" Perform an all-to-all-v operation, where by all processors request """ Perform an all-to-all-v operation, where by all processors request
the values corresponding to ther set of indices. the values corresponding to their set of indices.
Parameters Parameters
---------- ----------
......
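To make the push/pull semantics above concrete, here is a hedged single-process model of how sparse_all_to_all_push routes data under a remainder partition (the real implementation runs over NCCL on the GPU; the helper name is illustrative):

    import torch as th

    def route_by_remainder(idx, value, world_size):
        """Bucket local (index, gradient) rows by destination rank: row i is
        sent to rank idx[i] % world_size, the owner of that embedding row."""
        send_idx, send_val = [], []
        for dst in range(world_size):
            mask = (idx % world_size) == dst
            send_idx.append(idx[mask])
            send_val.append(value[mask])
        return send_idx, send_val

    idx = th.tensor([4, 7, 2, 5])
    grad = th.randn(4, 8)
    buckets_idx, _ = route_by_remainder(idx, grad, world_size=2)
    # buckets_idx == [tensor([4, 2]), tensor([7, 5])]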
...@@ -3,8 +3,11 @@ from datetime import timedelta ...@@ -3,8 +3,11 @@ from datetime import timedelta
import torch as th import torch as th
from ...backend import pytorch as F from ...backend import pytorch as F
from ...utils import get_shared_mem_array, create_shared_mem_array from ...utils import get_shared_mem_array, create_shared_mem_array
from ...cuda import nccl
from ...partition import NDArrayPartition
_STORE = None _STORE = None
_COMM = None
class NodeEmbedding: # NodeEmbedding class NodeEmbedding: # NodeEmbedding
'''Class for storing node embeddings. '''Class for storing node embeddings.
...@@ -36,6 +39,11 @@ class NodeEmbedding: # NodeEmbedding ...@@ -36,6 +39,11 @@ class NodeEmbedding: # NodeEmbedding
init_func : callable, optional init_func : callable, optional
The function to create the initial data. If the init function is not provided, The function to create the initial data. If the init function is not provided,
the values of the embeddings are initialized to zero. the values of the embeddings are initialized to zero.
device : th.device
Device to store the embeddings on.
partition : NDArrayPartition
The partition to use to distribute the embeddings between
processes.
Examples Examples
-------- --------
...@@ -58,8 +66,12 @@ class NodeEmbedding: # NodeEmbedding ...@@ -58,8 +66,12 @@ class NodeEmbedding: # NodeEmbedding
''' '''
def __init__(self, num_embeddings, embedding_dim, name, def __init__(self, num_embeddings, embedding_dim, name,
init_func=None): init_func=None, device=None, partition=None):
global _STORE global _STORE
global _COMM
if device is None:
device = th.device('cpu')
# Check whether it is multi-gpu training or not. # Check whether it is multi-gpu training or not.
if th.distributed.is_initialized(): if th.distributed.is_initialized():
...@@ -70,32 +82,76 @@ class NodeEmbedding: # NodeEmbedding ...@@ -70,32 +82,76 @@ class NodeEmbedding: # NodeEmbedding
world_size = 0 world_size = 0
self._rank = rank self._rank = rank
self._world_size = world_size self._world_size = world_size
self._store = None
self._comm = None
self._partition = partition
host_name = '127.0.0.1' host_name = '127.0.0.1'
port = 12346 port = 12346
if rank >= 0:
# for multi-gpu training, set up a TCPStore for
# embedding status synchronization across GPU processes
if _STORE is None:
_STORE = th.distributed.TCPStore(
host_name, port, world_size, rank == 0, timedelta(seconds=10*60))
self._store = _STORE
# embeddings are stored in CPU memory.
if th.device(device) == th.device('cpu'):
if rank <= 0: if rank <= 0:
emb = create_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32) emb = create_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32)
if init_func is not None: if init_func is not None:
emb = init_func(emb) emb = init_func(emb)
if rank == 0: # the master gpu process if rank == 0: # the master gpu process
# for multi-gpu training, setup a TCPStore for
# embeding status synchronization across GPU processes
if _STORE is None:
_STORE = th.distributed.TCPStore(
host_name, port, world_size, True, timedelta(seconds=10*60))
for _ in range(1, world_size): for _ in range(1, world_size):
# send embs # send embs
_STORE.set(name, name) self._store.set(name, name)
elif rank > 0: elif rank > 0:
# receive # receive
if _STORE is None: self._store.wait([name])
_STORE = th.distributed.TCPStore(
host_name, port, world_size, False, timedelta(seconds=10*60))
_STORE.wait([name])
emb = get_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32) emb = get_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32)
self._store = _STORE
self._tensor = emb self._tensor = emb
else: # embeddings are stored in GPU memory.
# setup nccl communicator
if _COMM is None:
if rank < 0:
_COMM = nccl.Communicator(1, 0, nccl.UniqueId())
else:
# needs to be set for nccl to work
th.cuda.set_device(device)
if rank == 0:
# root process broadcasts nccl id
nccl_id = nccl.UniqueId()
self._store.set('nccl_root_id', str(nccl_id))
else:
nccl_id = nccl.UniqueId(self._store.get('nccl_root_id'))
_COMM = nccl.Communicator(self._world_size, self._rank,
nccl_id)
if self._rank == 0:
# clear the store entry for future communicators
self._store.delete_key('nccl_root_id')
th.distributed.barrier()
self._comm = _COMM
if not self._partition:
# for communication we need a partition
self._partition = NDArrayPartition(
num_embeddings,
self._world_size if self._world_size > 0 else 1,
mode='remainder')
# create local tensors for the weights
local_size = self._partition.local_size(self._comm.rank())
# TODO(dlasalle): support 16-bit/half embeddings
emb = th.empty([local_size, embedding_dim], dtype=th.float32,
requires_grad=False, device=device)
if init_func:
emb = init_func(emb)
self._tensor = emb
self._num_embeddings = num_embeddings self._num_embeddings = num_embeddings
self._embedding_dim = embedding_dim self._embedding_dim = embedding_dim
self._name = name self._name = name
...@@ -109,10 +165,19 @@ class NodeEmbedding: # NodeEmbedding ...@@ -109,10 +165,19 @@ class NodeEmbedding: # NodeEmbedding
device : th.device device : th.device
Target device to put the collected embeddings. Target device to put the collected embeddings.
""" """
if not self._comm or self._comm.size() == 1:
emb = self._tensor[node_ids].to(device) emb = self._tensor[node_ids].to(device)
else:
if self.world_size > 0:
emb = self._comm.sparse_all_to_all_pull(
node_ids, self._tensor, self._partition)
else:
emb = self._tensor[node_ids]
emb = emb.to(device)
if F.is_recording(): if F.is_recording():
emb = F.attach_grad(emb) emb = F.attach_grad(emb)
self._trace.append((node_ids.to(device, non_blocking=True), emb)) self._trace.append((node_ids.to(device), emb))
return emb return emb
@property @property
...@@ -127,6 +192,31 @@ class NodeEmbedding: # NodeEmbedding ...@@ -127,6 +192,31 @@ class NodeEmbedding: # NodeEmbedding
""" """
return self._store return self._store
@property
def comm(self):
"""Return dgl.cuda.nccl.Communicator for data
sharing across processes.
Returns
-------
dgl.cuda.nccl.Communicator
Communicator used for data sharing.
"""
return self._comm
@property
def partition(self):
"""Return the partition identifying how the tensor is split across
processes.
Returns
-------
NDArrayPartition
The partition used to split the embedding across processes.
"""
return self._partition
@property @property
def rank(self): def rank(self):
"""Return rank of current process. """Return rank of current process.
...@@ -244,3 +334,47 @@ class NodeEmbedding: # NodeEmbedding ...@@ -244,3 +334,47 @@ class NodeEmbedding: # NodeEmbedding
The tensor storing the node embeddings The tensor storing the node embeddings
""" """
return self._tensor return self._tensor
def gather_embedding(self):
"""Return a copy of the embedding stored in CPU memory. If this is a
multi-processing instance, the tensor will be returned in shared
memory. If the embedding is currently stored on multiple GPUs, all
processes must call this method in the same order.
Returns
-------
torch.Tensor
The tensor storing the node embeddings.
"""
if self._partition:
if self._world_size == 0:
# non-multiprocessing
return self._tensor.to(th.device('cpu'))
else:
# create a shared memory tensor
shared_name = self._name + "_gather"
if self._rank == 0:
# root process creates shared memory
emb = create_shared_mem_array(
shared_name,
(self._num_embeddings, self._embedding_dim),
self._tensor.dtype)
self._store.set(shared_name, shared_name)
else:
self._store.wait([shared_name])
emb = get_shared_mem_array(
shared_name, (self._num_embeddings, self._embedding_dim),
self._tensor.dtype)
# need to map indices and slice into existing tensor
idxs = self._partition.map_to_global(
F.arange(0, self._tensor.shape[0],
ctx=F.context(self._tensor)),
self._rank).to(emb.device)
emb[idxs] = self._tensor.to(emb.device)
# wait for all processes to finish
th.distributed.barrier()
return emb
else:
# already stored in CPU memory
return self._tensor
...@@ -5,6 +5,8 @@ import torch as th ...@@ -5,6 +5,8 @@ import torch as th
from ...utils import get_shared_mem_array, create_shared_mem_array from ...utils import get_shared_mem_array, create_shared_mem_array
from ...nn.pytorch import NodeEmbedding from ...nn.pytorch import NodeEmbedding
from ...cuda import nccl
from ...partition import NDArrayPartition
class SparseGradOptimizer(abc.ABC): class SparseGradOptimizer(abc.ABC):
r''' The abstract sparse optimizer. r''' The abstract sparse optimizer.
...@@ -26,10 +28,14 @@ class SparseGradOptimizer(abc.ABC): ...@@ -26,10 +28,14 @@ class SparseGradOptimizer(abc.ABC):
self._shared_cache = {} self._shared_cache = {}
self._clean_grad = False self._clean_grad = False
self._opt_meta = {} self._opt_meta = {}
self._comm = None
self._first_step = True
self._device = None
# hold released shared memory to let other process to munmap it first # hold released shared memory to let other process to munmap it first
# otherwise it will crash the training # otherwise it will crash the training
self.shmem_buffer_holder = [] self.shmem_buffer_holder = []
# if we are using shared memory for communication
for emb in params: for emb in params:
assert isinstance(emb, NodeEmbedding), \ assert isinstance(emb, NodeEmbedding), \
'DGL SparseOptimizer only supports dgl.nn.NodeEmbedding' 'DGL SparseOptimizer only supports dgl.nn.NodeEmbedding'
...@@ -42,11 +48,86 @@ class SparseGradOptimizer(abc.ABC): ...@@ -42,11 +48,86 @@ class SparseGradOptimizer(abc.ABC):
'MultiGPU rank for each embedding should be same.' 'MultiGPU rank for each embedding should be same.'
assert self._world_size == emb.world_size, \ assert self._world_size == emb.world_size, \
'MultiGPU world_size for each embedding should be same.' 'MultiGPU world_size for each embedding should be same.'
assert self._rank is not None
assert self._world_size is not None
def step(self):
''' The step function.
The step function is invoked at the end of every batch to update embeddings
'''
# on the first step, check to see if the grads are on the GPU
if self._first_step:
for emb in self._params:
for _, data in emb._trace:
if data.grad.data.device.type == 'cuda':
# create a communicator
if self._device:
assert self._device == data.grad.device, \
"All gradients must be on the same device"
else:
self._device = data.grad.device
else:
assert not self._device, \
"All gradients must be on the same device"
if self._device:
# device is only set if the grads are on a GPU
self._comm_setup()
else:
self._shared_setup()
self.setup(self._params)
self._first_step = False
if self._comm:
self._comm_step()
else:
self._shared_step()
def setup(self, params):
''' This is the function where subclasses can perform any setup they need
to. It will be called during the first step, and communicators or
shared memory will have been set up before this call.
Parameters
----------
params : list of NodeEmbedding
The list of NodeEmbeddings.
'''
def _comm_setup(self):
# find a store to communicate the unique id through
if len(self._params) > 0:
store = self._params[0].store
if self._rank < 0:
self._comm = nccl.Communicator(1, 0, nccl.UniqueId())
else:
th.cuda.set_device(self._device)
if self._rank == 0:
# root process broadcasts nccl id
nccl_id = nccl.UniqueId()
uid = str(nccl_id)
store.set('nccl_root_id', uid)
else:
uid = store.get('nccl_root_id')
nccl_id = nccl.UniqueId(uid)
# needs to be set for nccl to work
self._comm = nccl.Communicator(self._world_size,
self._rank,
nccl_id)
if self._rank == 0:
# clear the store entry for future communicators
store.delete_key('nccl_root_id')
th.distributed.barrier()
def _shared_setup(self):
for emb in self._params:
emb_name = emb.name emb_name = emb.name
if self._rank == 0: # the master gpu process if self._rank == 0: # the master gpu process
opt_meta = create_shared_mem_array(emb_name+'_opt_meta', \ opt_meta = create_shared_mem_array(emb_name+'_opt_meta', \
(self._world_size, self._world_size), th.int32).zero_() (self._world_size, self._world_size), th.int32).zero_()
if self._rank == 0: if self._rank == 0:
emb.store.set(emb_name+'_opt_meta', emb_name) emb.store.set(emb_name+'_opt_meta', emb_name)
self._opt_meta[emb_name] = opt_meta self._opt_meta[emb_name] = opt_meta
...@@ -57,11 +138,64 @@ class SparseGradOptimizer(abc.ABC): ...@@ -57,11 +138,64 @@ class SparseGradOptimizer(abc.ABC):
(self._world_size, self._world_size), th.int32) (self._world_size, self._world_size), th.int32)
self._opt_meta[emb_name] = opt_meta self._opt_meta[emb_name] = opt_meta
def step(self): def _comm_step(self):
''' The step function. comm = self._comm
with th.no_grad():
idx_in = {}
grad_in = {}
for emb in self._params: # pylint: disable=too-many-nested-blocks
emb_name = emb.name
partition = emb.partition
The step function is invoked at the end of every batch to update embeddings if not partition:
''' # use default partitioning
partition = NDArrayPartition(
emb.num_embeddings,
self._world_size if self._world_size > 0 else 1,
mode='remainder')
# we need to combine gradients from multiple forward paths
if len(emb._trace) == 0:
idx = th.zeros((0,), dtype=th.long, device=self._device)
grad = th.zeros((0, emb.embedding_dim),
dtype=th.float32,
device=self._device)
elif len(emb._trace) == 1:
# the special case where we can use the tensors as is
# without any memcpy's
idx, grad = emb._trace[0]
grad = grad.grad.data
else:
idx = []
grad = []
for i, data in emb._trace:
idx.append(i)
grad.append(data.grad.data)
idx = th.cat(idx, dim=0)
grad = th.cat(grad, dim=0)
idx_in[emb_name], grad_in[emb_name] = \
comm.sparse_all_to_all_push(
idx, grad, partition=partition)
if emb.partition:
# if the embedding is partitioned, map back to indexes
# into the local tensor
idx_in[emb_name] = partition.map_to_local(idx_in[emb_name])
if self._clean_grad:
# clean gradient track
for emb in self._params:
emb.reset_trace()
self._clean_grad = False
for emb in self._params:
emb_name = emb.name
idx = idx_in[emb_name]
grad = grad_in[emb_name]
self.update(idx, grad, emb)
def _shared_step(self):
with th.no_grad(): with th.no_grad():
# Frequently alloc and free shared memory to hold intermediate tensor is expensive # Frequently alloc and free shared memory to hold intermediate tensor is expensive
# We cache shared memory buffers in shared_emb. # We cache shared memory buffers in shared_emb.
...@@ -280,13 +414,17 @@ class SparseAdagrad(SparseGradOptimizer): ...@@ -280,13 +414,17 @@ class SparseAdagrad(SparseGradOptimizer):
def __init__(self, params, lr, eps=1e-10): def __init__(self, params, lr, eps=1e-10):
super(SparseAdagrad, self).__init__(params, lr) super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps self._eps = eps
def setup(self, params):
# We need to register a state sum for each embedding in the kvstore. # We need to register a state sum for each embedding in the kvstore.
for emb in params: for emb in params:
assert isinstance(emb, NodeEmbedding), \ assert isinstance(emb, NodeEmbedding), \
'SparseAdagrad only supports dgl.nn.NodeEmbedding' 'SparseAdagrad only supports dgl.nn.NodeEmbedding'
if self._rank <= 0:
emb_name = emb.name emb_name = emb.name
if th.device(emb.emb_tensor.device) == th.device('cpu'):
# if our embedding is on the CPU, our state also has to be
if self._rank <= 0:
state = create_shared_mem_array(emb_name+'_state', \ state = create_shared_mem_array(emb_name+'_state', \
emb.weight.shape, th.float32).zero_() emb.weight.shape, th.float32).zero_()
if self._rank == 0: if self._rank == 0:
...@@ -294,10 +432,15 @@ class SparseAdagrad(SparseGradOptimizer): ...@@ -294,10 +432,15 @@ class SparseAdagrad(SparseGradOptimizer):
emb.store.set(emb_name+'_opt', emb_name) emb.store.set(emb_name+'_opt', emb_name)
elif self._rank > 0: elif self._rank > 0:
# receive # receive
emb_name = emb.name
emb.store.wait([emb_name+'_opt']) emb.store.wait([emb_name+'_opt'])
state = get_shared_mem_array(emb_name+'_state', \ state = get_shared_mem_array(emb_name+'_state', \
emb.weight.shape, th.float32) emb.weight.shape, th.float32)
else:
# distributed state on gpu
state = th.empty(
emb.emb_tensor.shape,
dtype=th.float32,
device=emb.emb_tensor.device).zero_()
emb.set_optm_state(state) emb.set_optm_state(state)
def update(self, idx, grad, emb): def update(self, idx, grad, emb):
...@@ -386,13 +529,16 @@ class SparseAdam(SparseGradOptimizer): ...@@ -386,13 +529,16 @@ class SparseAdam(SparseGradOptimizer):
self._beta1 = betas[0] self._beta1 = betas[0]
self._beta2 = betas[1] self._beta2 = betas[1]
self._eps = eps self._eps = eps
def setup(self, params):
# We need to register a state sum for each embedding in the kvstore. # We need to register a state sum for each embedding in the kvstore.
for emb in params: for emb in params:
assert isinstance(emb, NodeEmbedding), \ assert isinstance(emb, NodeEmbedding), \
'SparseAdam only supports dgl.nn.NodeEmbedding' 'SparseAdam only supports dgl.nn.NodeEmbedding'
if self._rank <= 0:
emb_name = emb.name emb_name = emb.name
if th.device(emb.emb_tensor.device) == th.device('cpu'):
# if our embedding is on the CPU, our state also has to be
if self._rank <= 0:
state_step = create_shared_mem_array(emb_name+'_step', \ state_step = create_shared_mem_array(emb_name+'_step', \
(emb.weight.shape[0],), th.float32).zero_() (emb.weight.shape[0],), th.float32).zero_()
state_mem = create_shared_mem_array(emb_name+'_mem', \ state_mem = create_shared_mem_array(emb_name+'_mem', \
...@@ -400,12 +546,10 @@ class SparseAdam(SparseGradOptimizer): ...@@ -400,12 +546,10 @@ class SparseAdam(SparseGradOptimizer):
state_power = create_shared_mem_array(emb_name+'_power', \ state_power = create_shared_mem_array(emb_name+'_power', \
emb.weight.shape, th.float32).zero_() emb.weight.shape, th.float32).zero_()
if self._rank == 0: if self._rank == 0:
emb_name = emb.name
if self._world_size > 1: if self._world_size > 1:
emb.store.set(emb_name+'_opt', emb_name) emb.store.set(emb_name+'_opt', emb_name)
elif self._rank > 0: elif self._rank > 0:
# receive # receive
emb_name = emb.name
emb.store.wait([emb_name+'_opt']) emb.store.wait([emb_name+'_opt'])
state_step = get_shared_mem_array(emb_name+'_step', \ state_step = get_shared_mem_array(emb_name+'_step', \
(emb.weight.shape[0],), th.float32) (emb.weight.shape[0],), th.float32)
...@@ -413,7 +557,20 @@ class SparseAdam(SparseGradOptimizer): ...@@ -413,7 +557,20 @@ class SparseAdam(SparseGradOptimizer):
emb.weight.shape, th.float32) emb.weight.shape, th.float32)
state_power = get_shared_mem_array(emb_name+'_power', \ state_power = get_shared_mem_array(emb_name+'_power', \
emb.weight.shape, th.float32) emb.weight.shape, th.float32)
else:
# distributed state on gpu
state_step = th.empty(
[emb.emb_tensor.shape[0]],
dtype=th.float32,
device=emb.emb_tensor.device).zero_()
state_mem = th.empty(
emb.emb_tensor.shape,
dtype=th.float32,
device=emb.emb_tensor.device).zero_()
state_power = th.empty(
emb.emb_tensor.shape,
dtype=th.float32,
device=emb.emb_tensor.device).zero_()
state = (state_step, state_mem, state_power) state = (state_step, state_mem, state_power)
emb.set_optm_state(state) emb.set_optm_state(state)
......
...@@ -420,5 +420,23 @@ class NDArrayPartition(object): ...@@ -420,5 +420,23 @@ class NDArrayPartition(object):
""" """
return self._partition return self._partition
def local_size(self, part):
""" Get the number of rows/items assigned to the given part.
"""
return _CAPI_DGLNDArrayPartitionGetPartSize(self._partition, part)
def map_to_local(self, idxs):
""" Convert the set of global indices to local indices
"""
return F.zerocopy_from_dgl_ndarray(_CAPI_DGLNDArrayPartitionMapToLocal(
self._partition,
F.zerocopy_to_dgl_ndarray(idxs)))
def map_to_global(self, idxs, part_id):
""" Convert the set of local indices ot global indices
"""
return F.zerocopy_from_dgl_ndarray(_CAPI_DGLNDArrayPartitionMapToGlobal(
self._partition, F.zerocopy_to_dgl_ndarray(idxs), part_id))
_init_api("dgl.partition") _init_api("dgl.partition")
...@@ -23,6 +23,7 @@ template<typename IdType> __global__ void _MapProcByRemainder( ...@@ -23,6 +23,7 @@ template<typename IdType> __global__ void _MapProcByRemainder(
const int64_t num_index, const int64_t num_index,
const int64_t num_proc, const int64_t num_proc,
IdType * const proc_id) { IdType * const proc_id) {
assert(num_index <= gridDim.x*blockDim.x);
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x; const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
if (idx < num_index) { if (idx < num_index) {
...@@ -36,6 +37,7 @@ __global__ void _MapProcByMaskRemainder( ...@@ -36,6 +37,7 @@ __global__ void _MapProcByMaskRemainder(
const int64_t num_index, const int64_t num_index,
const IdType mask, const IdType mask,
IdType * const proc_id) { IdType * const proc_id) {
assert(num_index <= gridDim.x*blockDim.x);
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x; const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
if (idx < num_index) { if (idx < num_index) {
...@@ -49,6 +51,7 @@ __global__ void _MapLocalIndexByRemainder( ...@@ -49,6 +51,7 @@ __global__ void _MapLocalIndexByRemainder(
IdType * const out, IdType * const out,
const int64_t num_items, const int64_t num_items,
const int comm_size) { const int comm_size) {
assert(num_items <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x; const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
if (idx < num_items) { if (idx < num_items) {
...@@ -56,6 +59,23 @@ __global__ void _MapLocalIndexByRemainder( ...@@ -56,6 +59,23 @@ __global__ void _MapLocalIndexByRemainder(
} }
} }
template<typename IdType>
__global__ void _MapGlobalIndexByRemainder(
const IdType * const in,
IdType * const out,
const int part_id,
const int64_t num_items,
const int comm_size) {
assert(num_items <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
assert(part_id < comm_size);
if (idx < num_items) {
out[idx] = (in[idx] * comm_size) + part_id;
}
}
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder( GeneratePermutationFromRemainder(
...@@ -80,6 +100,7 @@ GeneratePermutationFromRemainder( ...@@ -80,6 +100,7 @@ GeneratePermutationFromRemainder(
return result; return result;
} }
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx); result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
int64_t * out_counts = static_cast<int64_t*>(result.second->data); int64_t * out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) { if (num_in == 0) {
...@@ -117,7 +138,6 @@ GeneratePermutationFromRemainder( ...@@ -117,7 +138,6 @@ GeneratePermutationFromRemainder(
// then create a permutation array that groups processors together by // then create a permutation array that groups processors together by
// performing a radix sort // performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in); Workspace<IdType> proc_id_out(device, ctx, num_in);
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
IdType * perm_out = static_cast<IdType*>(result.first->data); IdType * perm_out = static_cast<IdType*>(result.first->data);
{ {
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx); IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
...@@ -199,6 +219,7 @@ IdArray MapToLocalFromRemainder( ...@@ -199,6 +219,7 @@ IdArray MapToLocalFromRemainder(
const auto& ctx = global_idx->ctx; const auto& ctx = global_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
if (num_parts > 1) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx, IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
sizeof(IdType)*8); sizeof(IdType)*8);
...@@ -217,6 +238,10 @@ IdArray MapToLocalFromRemainder( ...@@ -217,6 +238,10 @@ IdArray MapToLocalFromRemainder(
num_parts); num_parts);
return local_idx; return local_idx;
} else {
// no mapping to be done
return global_idx;
}
} }
template IdArray template IdArray
...@@ -228,6 +253,57 @@ MapToLocalFromRemainder<kDLGPU, int64_t>( ...@@ -228,6 +253,57 @@ MapToLocalFromRemainder<kDLGPU, int64_t>(
int num_parts, int num_parts,
IdArray in_idx); IdArray in_idx);
template <DLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
const int num_parts,
IdArray local_idx,
const int part_id) {
CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
"/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
"/" << num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
if (num_parts > 1) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
sizeof(IdType)*8);
const dim3 block(128);
const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);
CUDA_KERNEL_CALL(
_MapGlobalIndexByRemainder,
grid,
block,
0,
stream,
static_cast<const IdType*>(local_idx->data),
static_cast<IdType*>(global_idx->data),
part_id,
global_idx->shape[0],
num_parts);
return global_idx;
} else {
// no mapping to be done
return local_idx;
}
}
template IdArray
MapToGlobalFromRemainder<kDLGPU, int32_t>(
int num_parts,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRemainder<kDLGPU, int64_t>(
int num_parts,
IdArray in_idx,
int part_id);
} // namespace impl } // namespace impl
} // namespace partition } // namespace partition
......
...@@ -78,6 +78,31 @@ class RemainderPartition : public NDArrayPartition { ...@@ -78,6 +78,31 @@ class RemainderPartition : public NDArrayPartition {
// should be unreachable // should be unreachable
return IdArray{}; return IdArray{};
} }
IdArray MapToGlobal(
IdArray in_idx,
const int part_id) const override {
auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
if (ctx.device_type == kDLGPU) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToGlobalFromRemainder<kDLGPU, IdType>(
NumParts(), in_idx, part_id);
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for "
"partition of size " << NumParts() << ".";
return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());
}
}; };
NDArrayPartitionRef CreatePartitionRemainderBased( NDArrayPartitionRef CreatePartitionRemainderBased(
...@@ -95,5 +120,31 @@ DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased") ...@@ -95,5 +120,31 @@ DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
*rv = CreatePartitionRemainderBased(array_size, num_parts); *rv = CreatePartitionRemainderBased(array_size, num_parts);
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
int part_id = args[1];
*rv = part->PartSize(part_id);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
IdArray idxs = args[1];
*rv = part->MapToLocal(idxs);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0];
IdArray idxs = args[1];
const int part_id = args[2];
*rv = part->MapToGlobal(idxs, part_id);
});
} // namespace partition } // namespace partition
} // namespace dgl } // namespace dgl
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#define DGL_PARTITION_NDARRAY_PARTITION_H_ #define DGL_PARTITION_NDARRAY_PARTITION_H_
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <utility> #include <utility>
...@@ -64,6 +65,29 @@ class NDArrayPartition : public runtime::Object { ...@@ -64,6 +65,29 @@ class NDArrayPartition : public runtime::Object {
virtual IdArray MapToLocal( virtual IdArray MapToLocal(
IdArray in_idx) const = 0; IdArray in_idx) const = 0;
/**
* @brief Generate the global indices (the numbering unique across all
* processors) from a set of local indices.
*
* @param in_idx The local indices.
* @param part_id The part id.
*
* @return The global indices.
*/
virtual IdArray MapToGlobal(
IdArray in_idx,
int part_id) const = 0;
/**
* @brief Get the number of rows/items assigned to the given part.
*
* @param part_id The part id.
*
* @return The size.
*/
virtual int64_t PartSize(
int part_id) const = 0;
/** /**
* @brief Get the first dimension of the partitioned array. * @brief Get the first dimension of the partitioned array.
* *
......
...@@ -51,6 +51,25 @@ IdArray MapToLocalFromRemainder( ...@@ -51,6 +51,25 @@ IdArray MapToLocalFromRemainder(
int num_parts, int num_parts,
IdArray global_idx); IdArray global_idx);
/**
* @brief Generate the set of global indices from the local indices, using
* remainder. That is, for each index `i` in `local_idx`, the global index
* is computed as `local_idx[i] * num_parts + part_id`.
*
* @tparam XPU The type of device to run on.
* @tparam IdType The type of the index.
* @param num_parts The number of parts the array is divided into.
* @param local_idx The array of local indices to map.
* @param part_id The id of the current part.
*
* @return The array of global indices.
*/
template <DLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
int num_parts,
IdArray local_idx,
int part_id);
} // namespace impl } // namespace impl
} // namespace partition } // namespace partition
......
...@@ -138,10 +138,6 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -138,10 +138,6 @@ std::pair<IdArray, NDArray> SparsePush(
IdArray in_idx, IdArray in_idx,
NDArray in_value, NDArray in_value,
NDArrayPartitionRef part) { NDArrayPartitionRef part) {
CHECK_EQ(in_idx->shape[0], in_value->shape[0]) <<
"Leading dimension of indices (" << in_idx->shape[0] << ") must match "
"leading dimension of values (" << in_value->shape[0] << ").";
const auto& ctx = in_idx->ctx; const auto& ctx = in_idx->ctx;
CHECK_EQ(ctx, in_value->ctx) << "Indices and values must be on the same " CHECK_EQ(ctx, in_value->ctx) << "Indices and values must be on the same "
"device"; "device";
...@@ -150,9 +146,15 @@ std::pair<IdArray, NDArray> SparsePush( ...@@ -150,9 +146,15 @@ std::pair<IdArray, NDArray> SparsePush(
// TODO(dlasalle): Get the stream from the device context. // TODO(dlasalle): Get the stream from the device context.
cudaStream_t stream = 0; cudaStream_t stream = 0;
CHECK_EQ(in_idx->ndim, 1) << "Indices must be 1-dimensional"; CHECK_LE(in_idx->ndim, 1) << "The tensor of sending indices must be of "
"dimension one (or empty).";
const int64_t num_in = in_idx->ndim > 0 ? in_idx->shape[0] : 0;
CHECK_EQ(num_in, in_value->ndim > 0 ? in_value->shape[0] : 0) <<
"Leading dimension of indices (" << num_in << ") must match "
"leading dimension of values (" <<
(in_value->ndim > 0 ? in_value->shape[0] : 0) << ").";
const int64_t num_in = in_idx->shape[0];
int64_t num_feat = 1; int64_t num_feat = 1;
for (int d = 1; d < in_value->ndim; ++d) { for (int d = 1; d < in_value->ndim; ++d) {
num_feat *= in_value->shape[d]; num_feat *= in_value->shape[d];
...@@ -297,10 +299,9 @@ NDArray SparsePull( ...@@ -297,10 +299,9 @@ NDArray SparsePull(
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
CHECK_EQ(req_idx->ndim, 1) << "The tensor of requested indices must be of " CHECK_LE(req_idx->ndim, 1) << "The tensor of requested indices must be of "
"dimension one."; "dimension one (or empty).";
const int64_t num_in = req_idx->ndim > 0 ? req_idx->shape[0] : 0;
const int64_t num_in = req_idx->shape[0];
int64_t num_feat = 1; int64_t num_feat = 1;
for (int d = 1; d < local_tensor->ndim; ++d) { for (int d = 1; d < local_tensor->ndim; ++d) {
num_feat *= local_tensor->shape[d]; num_feat *= local_tensor->shape[d];
...@@ -328,7 +329,7 @@ NDArray SparsePull( ...@@ -328,7 +329,7 @@ NDArray SparsePull(
static_cast<const int64_t*>(part_perm.second->data); static_cast<const int64_t*>(part_perm.second->data);
// permute requests // permute requests
{ if (num_in > 0) {
const dim3 block(256); const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x); const dim3 grid((num_in+block.x-1)/block.x);
......
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../../src/partition/ndarray_partition.h" #include "../../src/partition/ndarray_partition.h"
#include "./common.h"
using namespace dgl; using namespace dgl;
...@@ -7,8 +8,7 @@ using namespace dgl::partition; ...@@ -7,8 +8,7 @@ using namespace dgl::partition;
template<DLDeviceType XPU, typename IdType> template<DLDeviceType XPU, typename IdType>
void _TestRemainder() void _TestRemainder_GeneratePermutation() {
{
const int64_t size = 160000; const int64_t size = 160000;
const int num_parts = 7; const int num_parts = 7;
NDArrayPartitionRef part = CreatePartitionRemainderBased( NDArrayPartitionRef part = CreatePartitionRemainderBased(
...@@ -48,10 +48,43 @@ void _TestRemainder() ...@@ -48,10 +48,43 @@ void _TestRemainder()
} }
} }
template<DLDeviceType XPU, typename IdType>
void _TestRemainder_MapToX() {
const int64_t size = 160000;
const int num_parts = 7;
NDArrayPartitionRef part = CreatePartitionRemainderBased(
size, num_parts);
for (int part_id = 0; part_id < num_parts; ++part_id) {
IdArray local = aten::Range(0, part->PartSize(part_id), sizeof(IdType)*8,
DGLContext{XPU, 0});
IdArray global = part->MapToGlobal(local, part_id);
IdArray act_local = part->MapToLocal(global).CopyTo(CPU);
// every global index should have the same remainder as the part id
ASSERT_EQ(global->shape[0], local->shape[0]);
global = global.CopyTo(CPU);
for (size_t i = 0; i < global->shape[0]; ++i) {
EXPECT_EQ(Ptr<IdType>(global)[i] % num_parts, part_id) << "i=" << i <<
", num_parts=" << num_parts << ", part_id=" << part_id;
}
// the remapped local indices should match the original
local = local.CopyTo(CPU);
ASSERT_EQ(local->shape[0], act_local->shape[0]);
for (size_t i = 0; i < act_local->shape[0]; ++i) {
EXPECT_EQ(Ptr<IdType>(local)[i], Ptr<IdType>(act_local)[i]);
}
}
}
TEST(PartitionTest, TestRemainderPartition) { TEST(PartitionTest, TestRemainderPartition) {
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
_TestRemainder<kDLGPU, int32_t>(); _TestRemainder_GeneratePermutation<kDLGPU, int32_t>();
_TestRemainder<kDLGPU, int64_t>(); _TestRemainder_GeneratePermutation<kDLGPU, int64_t>();
_TestRemainder_MapToX<kDLGPU, int32_t>();
_TestRemainder_MapToX<kDLGPU, int64_t>();
#endif #endif
// CPU is not implemented // CPU is not implemented
......
...@@ -144,6 +144,9 @@ def start_sparse_adam_worker(rank, world_size, has_zero_grad=False, num_embs=128 ...@@ -144,6 +144,9 @@ def start_sparse_adam_worker(rank, world_size, has_zero_grad=False, num_embs=128
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("num_workers", [2, 4, 8]) @pytest.mark.parametrize("num_workers", [2, 4, 8])
def test_multiprocess_sparse_adam(num_workers): def test_multiprocess_sparse_adam(num_workers):
if F.ctx().type == 'cuda' and th.cuda.device_count() < num_workers:
pytest.skip("Not enough GPUs to run test.")
worker_list = [] worker_list = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
...@@ -159,6 +162,9 @@ def test_multiprocess_sparse_adam(num_workers): ...@@ -159,6 +162,9 @@ def test_multiprocess_sparse_adam(num_workers):
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("num_workers", [2, 4, 8]) @pytest.mark.parametrize("num_workers", [2, 4, 8])
def test_multiprocess_sparse_adam_zero_step(num_workers): def test_multiprocess_sparse_adam_zero_step(num_workers):
if F.ctx().type == 'cuda' and th.cuda.device_count() < num_workers:
pytest.skip("Not enough GPUs to run test.")
worker_list = [] worker_list = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
......