Unverified Commit 21a40279 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Fix distributed device mapping problem. (#3313)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-2-66.ec2.internal>
parent 5a245104
......@@ -279,7 +279,7 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
device = th.device('cuda:'+str(args.local_rank))
labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
print('#labels:', n_classes)
......
......@@ -276,7 +276,7 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
device = th.device('cuda:'+str(args.local_rank))
labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
print('#labels:', n_classes)
......
......@@ -435,7 +435,7 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
device = th.device('cuda:'+str(args.local_rank))
# Pack data
in_feats = g.ndata['features'].shape[1]
......
......@@ -198,7 +198,7 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
device = th.device('cuda:'+str(args.local_rank))
# Pack data
global_train_nid = global_train_nid.squeeze()
......
......@@ -573,7 +573,7 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(g.rank() % args.num_gpus))
device = th.device('cuda:'+str(args.local_rank))
labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))]
all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment