Unverified Commit 4889c578 authored by xcwan, committed by GitHub

Add 'nccl' backend in train_dist.py and fix pad_data function cuda bug (#3607)



* Add nccl backend and fix pad_data function cuda bug

* Update train_dist.py

* Update train_dist.py
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 42897c36
@@ -155,7 +155,7 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
     model.train()
     return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid])
 
-def pad_data(nids):
+def pad_data(nids, device):
     """
     In distributed traning scenario, we need to make sure that each worker has same number of
     batches. Otherwise the synchronization(barrier) is called diffirent times, which results in
@@ -165,7 +165,8 @@ def pad_data(nids):
     the maximum size among all workers.
     """
     import torch.distributed as dist
-    num_nodes = th.tensor(nids.numel())
+    # NCCL backend only supports GPU tensors, thus here we need to allocate it to gpu
+    num_nodes = th.tensor(nids.numel()).to(device)
     dist.all_reduce(num_nodes, dist.ReduceOp.MAX)
     max_num_nodes = int(num_nodes)
     nids_length = nids.shape[0]
@@ -183,7 +184,7 @@ def pad_data(nids):
 def run(args, device, data):
     # Unpack data
     train_nid, val_nid, test_nid, in_feats, n_classes, g = data
-    train_nid = pad_data(train_nid)
+    train_nid = pad_data(train_nid, device)
     # Create sampler
     sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                               dgl.distributed.sample_neighbors, device)
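
In these three hunks, pad_data gains a device argument because, with the 'nccl' backend, torch.distributed.all_reduce only operates on CUDA tensors; the per-worker node count is therefore allocated on the training device before the collective call, and run() passes the device through. A minimal sketch of the resulting function follows. The padding step itself is not shown in the diff, so the cyclic-repeat strategy below is an assumption for illustration, not the script's actual code.

# Minimal sketch of the patched pad_data, assuming a cyclic-repeat padding
# strategy for the part of the function the diff does not show.
import torch as th
import torch.distributed as dist

def pad_data(nids, device):
    # NCCL only supports GPU tensors, so the per-worker node count must live
    # on the training device before all_reduce is called.
    num_nodes = th.tensor(nids.numel()).to(device)
    dist.all_reduce(num_nodes, dist.ReduceOp.MAX)
    max_num_nodes = int(num_nodes)

    # Assumed padding step: repeat existing IDs until this worker holds as
    # many training nodes as the largest worker, so every worker runs the
    # same number of batches and reaches each synchronization barrier the
    # same number of times.
    deficit = max_num_nodes - nids.numel()
    if deficit > 0:
        reps = (deficit + nids.numel() - 1) // nids.numel()
        nids = th.cat([nids, nids.repeat(reps)[:deficit]])
    return nids
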
@@ -281,7 +282,7 @@ def run(args, device, data):
 
 def main(args):
     dgl.distributed.initialize(args.ip_config)
     if not args.standalone:
-        th.distributed.init_process_group(backend='gloo')
+        th.distributed.init_process_group(backend=args.backend)
     g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
     print('rank:', g.rank())
@@ -325,6 +326,7 @@ if __name__ == '__main__':
     parser.add_argument('--part_config', type=str, help='The path to the partition config file')
     parser.add_argument('--num_clients', type=int, help='The number of clients')
     parser.add_argument('--n_classes', type=int, help='the number of classes')
+    parser.add_argument('--backend', type=str, default='gloo', help='pytorch distributed backend')
     parser.add_argument('--num_gpus', type=int, default=-1,
                         help="the number of GPU device. Use -1 for CPU training")
    parser.add_argument('--num_epochs', type=int, default=20)
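
The last two hunks make the process-group backend configurable: main() now passes args.backend to torch.distributed.init_process_group, and a --backend flag (defaulting to 'gloo', so existing launches are unchanged) is added to the parser. The sketch below shows one way the flag can flow into initialization; the local_rank-based CUDA device binding is an assumption for illustration, since NCCL requires each process to use its own GPU but the actual device selection happens elsewhere in train_dist.py.

# Sketch of wiring a --backend flag into process-group initialization.
# Assumes the launcher (e.g. torchrun) sets the usual RANK / WORLD_SIZE /
# MASTER_ADDR environment variables; the device binding below is
# illustrative, not the script's actual logic.
import argparse
import torch as th
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--backend', type=str, default='gloo',
                    help='pytorch distributed backend')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

dist.init_process_group(backend=args.backend)
if args.backend == 'nccl':
    # One GPU per process, indexed by the worker's local rank (assumed).
    th.cuda.set_device(args.local_rank)
    device = th.device('cuda', args.local_rank)
else:
    device = th.device('cpu')

Launching with --backend nccl then switches every collective to GPU tensors, which is exactly what the pad_data change above accommodates; the default --backend gloo keeps the previous CPU-tensor behavior.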