"vscode:/vscode.git/clone" did not exist on "a82825dbc1673496a8c6bbafb38f9d3be62616b8"
Unverified Commit 4889c578 authored by xcwan, committed by GitHub

Add 'nccl' backend in train_dist.py and fix pad_data function CUDA bug (#3607)



* Add NCCL backend and fix pad_data function CUDA bug

* Update train_dist.py

* Update train_dist.py
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 42897c36
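
The crux of the fix: collective operations such as all_reduce only accept CUDA tensors under the NCCL backend, whereas Gloo works with CPU tensors. A minimal standalone sketch of the pattern (the helper name and setup here are illustrative, not part of this commit):

import torch
import torch.distributed as dist

# Assumes the process group was already initialized, e.g.
#   dist.init_process_group(backend='nccl')  # or 'gloo'

def global_max(value, device):
    # Under NCCL the tensor must live on a CUDA device; passing a CPU
    # tensor would raise an error. Gloo accepts CPU tensors directly.
    t = torch.tensor(value).to(device)
    dist.all_reduce(t, dist.ReduceOp.MAX)
    return int(t)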
@@ -155,7 +155,7 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid])
-def pad_data(nids):
+def pad_data(nids, device):
"""
In a distributed training scenario, we need to make sure that each worker has the same number of
batches. Otherwise the synchronization (barrier) is called a different number of times, which results in
@@ -165,7 +165,8 @@ def pad_data(nids):
the maximum size among all workers.
"""
import torch.distributed as dist
-    num_nodes = th.tensor(nids.numel())
+    # The NCCL backend only supports GPU tensors, so we need to allocate this one on the GPU.
+    num_nodes = th.tensor(nids.numel()).to(device)
dist.all_reduce(num_nodes, dist.ReduceOp.MAX)
max_num_nodes = int(num_nodes)
nids_length = nids.shape[0]
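
For context, the padding this hunk touches works roughly as in the sketch below; the repeat-based padding step is illustrative, since the diff elides the lines that do the actual padding:

import torch as th
import torch.distributed as dist

def pad_data_sketch(nids, device):
    # Agree on the largest per-worker nid count; the tensor must be on
    # `device` because NCCL collectives only support GPU tensors.
    num_nodes = th.tensor(nids.numel()).to(device)
    dist.all_reduce(num_nodes, dist.ReduceOp.MAX)
    max_num_nodes = int(num_nodes)
    # Pad the local nids up to the global maximum by recycling leading
    # entries, so every worker iterates the same number of batches.
    pad = max_num_nodes - nids.numel()
    if pad > 0:
        nids = th.cat([nids, nids[:pad]])
    return nids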
@@ -183,7 +184,7 @@ def pad_data(nids):
def run(args, device, data):
# Unpack data
train_nid, val_nid, test_nid, in_feats, n_classes, g = data
-    train_nid = pad_data(train_nid)
+    train_nid = pad_data(train_nid, device)
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
dgl.distributed.sample_neighbors, device)
@@ -281,7 +282,7 @@ def run(args, device, data):
def main(args):
dgl.distributed.initialize(args.ip_config)
if not args.standalone:
-        th.distributed.init_process_group(backend='gloo')
+        th.distributed.init_process_group(backend=args.backend)
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
print('rank:', g.rank())
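
Since NCCL cannot run without CUDA devices, a natural guard (hypothetical, not part of this commit) is to fall back to Gloo on CPU-only workers:

import torch as th
import torch.distributed as dist

def pick_backend(requested, num_gpus):
    # Hypothetical guard: NCCL requires CUDA, so fall back to Gloo
    # when training on CPU (num_gpus == -1 means CPU training here).
    if requested == 'nccl' and (num_gpus <= 0 or not th.cuda.is_available()):
        return 'gloo'
    return requested

# e.g. dist.init_process_group(backend=pick_backend(args.backend, args.num_gpus))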
@@ -325,6 +326,7 @@ if __name__ == '__main__':
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num_clients', type=int, help='The number of clients')
parser.add_argument('--n_classes', type=int, help='the number of classes')
+    parser.add_argument('--backend', type=str, default='gloo', help='pytorch distributed backend')
parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training")
parser.add_argument('--num_epochs', type=int, default=20)
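
A usage sketch with the new flag; the graph name and file paths below are placeholders, and only the flags shown in this diff are assumed to exist:

python3 train_dist.py --graph_name my-graph --ip_config ip_config.txt \
    --part_config data/my-graph.json --num_gpus 4 --backend nccl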