Unverified Commit e14860d9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Example][Bug] fix device index for dist train on GPUs (#4403)

parent cf4727a9
......@@ -573,7 +573,8 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(args.local_rank))
dev_id = g.rank() % args.num_gpus
device = th.device('cuda:'+str(dev_id))
labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))]
all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment