Unverified Commit d0638b1e, authored by Da Zheng and committed by GitHub

[Example] Update the example of distributed GNN training of RGCN (#2709)



* support gpu training.

* remove unnecessary arguments.

* update README.

* update time measurement.

* add zero_grad.
Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-132.us-west-1.compute.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 668bd928
@@ -120,7 +120,7 @@ python3 ~/workspace/dgl/tools/launch.py \
 --num_samplers 4 \
 --part_config data/ogbn-mag.json \
 --ip_config ip_config.txt \
-"python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 512 --n-hidden 64 --lr 0.01 --eval-batch-size 16 --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --num-workers 4 --num-servers 1 --sparse-embedding --sparse-lr 0.06 --node-feats"
+"python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag --fanout='25,25' --batch-size 1024 --n-hidden 64 --lr 0.01 --eval-batch-size 1024 --low-mem --dropout 0.5 --use-self-loop --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt --num-workers 4 --num-servers 1 --sparse-embedding --sparse-lr 0.06 --num_gpus 1"
 ```
 We can get the performance score at the second epoch:
...
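For machines with more than one GPU, the quoted trainer command can be adjusted through `--num_gpus`; each trainer then runs on `cuda:(rank % num_gpus)` (see the code changes below). A sketch of the quoted command only, assuming four GPUs per machine; the surrounding `launch.py` options stay as above:

```
python3 entity_classify_dist.py --graph-name ogbn-mag --dataset ogbn-mag \
    --fanout='25,25' --batch-size 1024 --n-hidden 64 --lr 0.01 \
    --eval-batch-size 1024 --low-mem --dropout 0.5 --use-self-loop \
    --n-bases 2 --n-epochs 3 --layer-norm --ip-config ip_config.txt \
    --num-workers 4 --num-servers 1 --sparse-embedding --sparse-lr 0.06 \
    --num_gpus 4
```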
@@ -371,12 +371,22 @@ def run(args, device, data):
                           low_mem=args.low_mem,
                           layer_norm=args.layer_norm)
     model = model.to(device)
     if not args.standalone:
-        model = th.nn.parallel.DistributedDataParallel(model)
-        # If there are dense parameters in the embedding layer
-        # or we use Pytorch saprse embeddings.
-        if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:
-            embed_layer = DistributedDataParallel(embed_layer, device_ids=None, output_device=None)
+        if args.num_gpus == -1:
+            model = DistributedDataParallel(model)
+            # If there are dense parameters in the embedding layer
+            # or we use Pytorch saprse embeddings.
+            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:
+                embed_layer = DistributedDataParallel(embed_layer)
+        else:
+            dev_id = g.rank() % args.num_gpus
+            model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
+            # If there are dense parameters in the embedding layer
+            # or we use Pytorch saprse embeddings.
+            if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:
+                embed_layer = embed_layer.to(device)
+                embed_layer = DistributedDataParallel(embed_layer, device_ids=[dev_id], output_device=dev_id)
 
     if args.sparse_embedding:
         if args.dgl_sparse and args.standalone:
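The branch added above follows the standard DistributedDataParallel pattern for CPU versus single-GPU-per-trainer setups. A minimal sketch of that pattern in isolation; `rank`, `num_gpus`, and the wrapped module are illustrative placeholders (the example itself derives the rank from its distributed graph via `g.rank()`):

```python
import torch as th
from torch.nn.parallel import DistributedDataParallel

def wrap_for_distributed(module, rank, num_gpus):
    """Wrap a module for distributed training; assumes
    th.distributed.init_process_group() has already been called."""
    if num_gpus == -1:
        # CPU training: no device pinning, DDP synchronizes gradients as usual.
        return DistributedDataParallel(module)
    # GPU training: each trainer process owns one local GPU, chosen by its rank.
    dev_id = rank % num_gpus
    module = module.to(th.device('cuda:' + str(dev_id)))
    return DistributedDataParallel(module, device_ids=[dev_id], output_device=dev_id)
```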
@@ -391,14 +401,14 @@ def run(args, device, data):
             else:
                 emb_optimizer = th.optim.SparseAdam(list(embed_layer.module.node_embeds.parameters()), lr=args.sparse_lr)
                 print('optimize Pytorch sparse embedding:', embed_layer.module.node_embeds)
         dense_params = list(model.parameters())
-        if args.node_feats:
-            if args.standalone:
-                dense_params += list(embed_layer.node_projs.parameters())
-                print('optimize dense projection:', embed_layer.node_projs)
-            else:
-                dense_params += list(embed_layer.module.node_projs.parameters())
-                print('optimize dense projection:', embed_layer.module.node_projs)
+        if args.standalone:
+            dense_params += list(embed_layer.node_projs.parameters())
+            print('optimize dense projection:', embed_layer.node_projs)
+        else:
+            dense_params += list(embed_layer.module.node_projs.parameters())
+            print('optimize dense projection:', embed_layer.module.node_projs)
         optimizer = th.optim.Adam(dense_params, lr=args.lr, weight_decay=args.l2norm)
     else:
         all_params = list(model.parameters()) + list(embed_layer.parameters())
@@ -439,7 +449,7 @@ def run(args, device, data):
             for block in blocks:
                 gen_norm(block)
             feats = embed_layer(blocks[0].srcdata[dgl.NID], blocks[0].srcdata[dgl.NTYPE])
-            label = labels[seeds]
+            label = labels[seeds].to(device)
             copy_time = time.time()
             feat_copy_t.append(copy_time - tic_step)
@@ -450,17 +460,17 @@ def run(args, device, data):
             # backward
             optimizer.zero_grad()
-            if args.sparse_embedding and not args.dgl_sparse:
+            if args.sparse_embedding:
                 emb_optimizer.zero_grad()
             loss.backward()
-            optimizer.step()
-            if args.sparse_embedding:
-                emb_optimizer.step()
             compute_end = time.time()
             forward_t.append(forward_end - copy_time)
             backward_t.append(compute_end - forward_end)
-            # Aggregate gradients in multiple nodes.
+            # Update model parameters
+            optimizer.step()
+            if args.sparse_embedding:
+                emb_optimizer.step()
             update_t.append(time.time() - compute_end)
             step_t = time.time() - start
             step_time.append(step_t)
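The reordering above moves the optimizer steps after the backward-time measurement, so `update_t` now isolates parameter updates, and `emb_optimizer.zero_grad()` is called for both sparse-embedding paths, which the new `SparseAdagrad.zero_grad()` further down makes possible. A minimal sketch of the resulting per-iteration structure; `forward_fn` is a placeholder standing in for the example's forward pass and loss computation:

```python
import time

def timed_step(forward_fn, optimizer, emb_optimizer=None):
    """One training iteration split into forward / backward / update timings,
    mirroring the ordering used after this change."""
    start = time.time()
    loss = forward_fn()                 # forward pass, returns the loss
    forward_end = time.time()

    optimizer.zero_grad()
    if emb_optimizer is not None:       # sparse-embedding optimizer, if any
        emb_optimizer.zero_grad()
    loss.backward()
    compute_end = time.time()

    # Parameter updates are timed separately from the backward pass.
    optimizer.step()
    if emb_optimizer is not None:
        emb_optimizer.step()
    update_end = time.time()

    return (forward_end - start,
            compute_end - forward_end,
            update_end - compute_end)
```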
@@ -504,7 +514,10 @@ def main(args):
         g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
         len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)),
         len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid))))
-    device = th.device('cpu')
+    if args.num_gpus == -1:
+        device = th.device('cpu')
+    else:
+        device = th.device('cuda:'+str(g.rank() % args.num_gpus))
     labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))]
     all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
     all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
@@ -524,8 +537,8 @@ if __name__ == '__main__':
     parser.add_argument('--num-servers', type=int, default=1, help='Server count on each machine.')
     # rgcn related
-    parser.add_argument("--gpu", type=str, default='0',
-                        help="gpu")
+    parser.add_argument('--num_gpus', type=int, default=-1,
+                        help="the number of GPU device. Use -1 for CPU training")
     parser.add_argument("--dropout", type=float, default=0,
                         help="dropout probability")
     parser.add_argument("--n-hidden", type=int, default=16,
@@ -561,14 +574,10 @@ if __name__ == '__main__':
                         help="Number of workers for distributed dataloader.")
     parser.add_argument("--low-mem", default=False, action='store_true',
                         help="Whether use low mem RelGraphCov")
-    parser.add_argument("--mix-cpu-gpu", default=False, action='store_true',
-                        help="Whether store node embeddins in cpu")
     parser.add_argument("--sparse-embedding", action='store_true',
                         help='Use sparse embedding for node embeddings.')
     parser.add_argument("--dgl-sparse", action='store_true',
                         help='Whether to use DGL sparse embedding')
-    parser.add_argument('--node-feats', default=False, action='store_true',
-                        help='Whether use node features')
     parser.add_argument('--layer-norm', default=False, action='store_true',
                         help='Use layer norm')
     parser.add_argument('--local_rank', type=int, help='get rank of the process')
...
@@ -76,6 +76,11 @@ class DistEmbedding:
             self._trace.append((idx, emb))
         return emb
 
+    def reset_trace(self):
+        '''Reset the traced data.
+        '''
+        self._trace = []
+
 class SparseAdagradUDF:
     ''' The UDF to update the embeddings with sparse Adagrad.
@@ -151,6 +156,7 @@ class SparseAdagrad:
     def __init__(self, params, lr):
         self._params = params
         self._lr = lr
+        self._clean_grad = False
         # We need to register a state sum for each embedding in the kvstore.
         for emb in params:
             assert isinstance(emb, DistEmbedding), 'SparseAdagrad only supports DistEmbeding'
@@ -185,5 +191,15 @@ class SparseAdagrad:
             # after we push them.
             grads = F.cat(grads, 0)
             kvstore.push(name, idxs, grads)
-            # Clean up the old traces.
-            emb._trace = []
+
+        if self._clean_grad:
+            # clean gradient track
+            for emb in self._params:
+                emb.reset_trace()
+            self._clean_grad = False
+
+    def zero_grad(self):
+        """clean grad cache
+        """
+        self._clean_grad = True
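With `reset_trace()` and `zero_grad()` in place, a trainer can treat `SparseAdagrad` like a regular optimizer: call `zero_grad()` each iteration and let the next `step()` clear the embedding traces after pushing gradients. A minimal usage sketch; `node_emb`, `dataloader`, `batch_ids`, and `compute_loss` are illustrative placeholders, and the optimizer is assumed to be exposed as `dgl.distributed.SparseAdagrad` as in this example:

```python
import dgl

# node_emb is assumed to be a dgl.distributed.DistEmbedding created elsewhere.
emb_optimizer = dgl.distributed.SparseAdagrad([node_emb], lr=0.06)

for batch_ids in dataloader:          # placeholder minibatch loop
    emb_optimizer.zero_grad()         # mark the traced gradients for cleanup
    feats = node_emb(batch_ids)       # forward pass records (ids, emb) in the trace
    loss = compute_loss(feats)        # placeholder loss on the looked-up embeddings
    loss.backward()
    emb_optimizer.step()              # push sparse grads to the kvstore, then reset
                                      # the traces because zero_grad() was called
```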