Unverified commit c53bf774, authored by Chao Ma, committed by GitHub
Browse files

[Distributed Example] Add --close_profiler option to train_dist.py (#2199)

* update

* update
parent 0ced22c2
......@@ -187,13 +187,13 @@ def run(args, device, data):
# Training loop
iter_tput = []
profiler = Profiler()
profiler.start()
if args.close_profiler == False:
profiler.start()
epoch = 0
for epoch in range(args.num_epochs):
tic = time.time()
sample_time = 0
copy_time = 0
forward_time = 0
backward_time = 0
update_time = 0
......@@ -242,8 +242,8 @@ def run(args, device, data):
start = time.time()
toc = time.time()
print('Part {}, Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
g.rank(), toc - tic, sample_time, copy_time, forward_time, backward_time, update_time, num_seeds, num_inputs))
print('Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
g.rank(), toc - tic, sample_time, forward_time, backward_time, update_time, num_seeds, num_inputs))
epoch += 1
......@@ -253,9 +253,9 @@ def run(args, device, data):
g.ndata['labels'], val_nid, test_nid, args.batch_size_eval, device)
print('Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}'.format(g.rank(), val_acc, test_acc,
time.time() - start))
profiler.stop()
print(profiler.output_text(unicode=True, color=True))
if args.close_profiler == False:
profiler.stop()
print(profiler.output_text(unicode=True, color=True))
def main(args):
dgl.distributed.initialize(args.ip_config, args.num_servers, num_workers=args.num_workers)
......@@ -313,6 +313,7 @@ if __name__ == '__main__':
help="Number of sampling processes. Use 0 for no extra process.")
parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
parser.add_argument('--close_profiler', action='store_true', help='Close pyinstrument profiler')
args = parser.parse_args()
assert args.num_workers == int(os.environ.get('DGL_NUM_SAMPLER')), \
'The num_workers should be the same value with num_samplers.'
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment