"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "ff090f698e0fbb2382f7c8ec2ca20a2e4500efdd"
Unverified Commit e14860d9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Example][Bug] fix device index for dist train on GPUs (#4403)

parent cf4727a9
...@@ -573,7 +573,8 @@ def main(args): ...@@ -573,7 +573,8 @@ def main(args):
if args.num_gpus == -1: if args.num_gpus == -1:
device = th.device('cpu') device = th.device('cpu')
else: else:
device = th.device('cuda:'+str(args.local_rank)) dev_id = g.rank() % args.num_gpus
device = th.device('cuda:'+str(dev_id))
labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))] labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))]
all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze() all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze() all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment