Unverified commit 1bc77c5e, authored by Rhett Ying, committed by GitHub
Browse files

[examples] Reduce memory consumption (#4558)

* [examples] Reduce memory consumption

* refine help message

* refine
parent 9a00cf19
...@@ -185,8 +185,6 @@ def run(args, device, data): ...@@ -185,8 +185,6 @@ def run(args, device, data):
loss_fcn = loss_fcn.to(device) loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
train_size = th.sum(g.ndata['train_mask'][0:g.number_of_nodes()])
# Training loop # Training loop
iter_tput = [] iter_tput = []
epoch = 0 epoch = 0
...@@ -284,13 +282,17 @@ def main(args): ...@@ -284,13 +282,17 @@ def main(args):
g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)), g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)), len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)),
len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid)))) len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid))))
del local_nid
if args.num_gpus == -1: if args.num_gpus == -1:
device = th.device('cpu') device = th.device('cpu')
else: else:
dev_id = g.rank() % args.num_gpus dev_id = g.rank() % args.num_gpus
device = th.device('cuda:'+str(dev_id)) device = th.device('cuda:'+str(dev_id))
n_classes = args.n_classes
if n_classes == -1:
labels = g.ndata['labels'][np.arange(g.number_of_nodes())] labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))])) n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
del labels
print('#labels:', n_classes) print('#labels:', n_classes)
# Pack data # Pack data
...@@ -307,8 +309,12 @@ if __name__ == '__main__': ...@@ -307,8 +309,12 @@ if __name__ == '__main__':
parser.add_argument('--ip_config', type=str, help='The file for IP configuration') parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file') parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num_clients', type=int, help='The number of clients') parser.add_argument('--num_clients', type=int, help='The number of clients')
parser.add_argument('--n_classes', type=int, help='the number of classes') parser.add_argument('--n_classes', type=int, default=-1,
parser.add_argument('--backend', type=str, default='gloo', help='pytorch distributed backend') help='The number of classes. If not specified, this'
' value will be calculated via scaning all the labels'
' in the dataset which probably causes memory burst.')
parser.add_argument('--backend', type=str, default='gloo',
help='pytorch distributed backend')
parser.add_argument('--num_gpus', type=int, default=-1, parser.add_argument('--num_gpus', type=int, default=-1,
help="the number of GPU device. Use -1 for CPU training") help="the number of GPU device. Use -1 for CPU training")
parser.add_argument('--num_epochs', type=int, default=20) parser.add_argument('--num_epochs', type=int, default=20)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment