"...text-generation-inference.git" did not exist on "0e97af456af3102ed4f927f7b7e870ec976079ae"
Unverified Commit e14860d9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Example][Bug] fix device index for dist train on GPUs (#4403)

parent cf4727a9
...@@ -573,7 +573,8 @@ def main(args): ...@@ -573,7 +573,8 @@ def main(args):
if args.num_gpus == -1: if args.num_gpus == -1:
device = th.device('cpu') device = th.device('cpu')
else: else:
device = th.device('cuda:'+str(args.local_rank)) dev_id = g.rank() % args.num_gpus
device = th.device('cuda:'+str(dev_id))
labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))] labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))]
all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze() all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze() all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment