Unverified Commit 6fb16100 authored by Sean Naren, committed by GitHub

Replace timer print rank 0 with logging (#732)



* Use log_dist function instead of print

* Expose ranks
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 248f6383
@@ -4,6 +4,7 @@ Copyright 2019 The Microsoft DeepSpeed Team
 import time
 import torch
+from deepspeed.utils.logging import log_dist
 from deepspeed.utils import logger
@@ -15,14 +16,6 @@ except ImportError:
     pass
-def print_rank_0(message):
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(message)
-    else:
-        print(message)
 class SynchronizedWallClockTimer:
     """Group of timers. Borrowed from Nvidia Megatron code"""
     class Timer:
@@ -88,7 +81,7 @@ class SynchronizedWallClockTimer:
             torch.cuda.max_memory_cached() / (1024 * 1024 * 1024))
         return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)
-    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False):
+    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
         """Log a group of timers."""
         assert normalizer > 0.0
         string = f'rank={torch.distributed.get_rank()} time (ms)'
@@ -98,9 +91,7 @@ class SynchronizedWallClockTimer:
                 reset=reset) * 1000.0 / normalizer
             string += ' | {}: {:.2f}'.format(name, elapsed_time)
-        # TODO: use our logging utilitied to selectively print. Useful for model
-        # parallelism because rank=0 is too restrictive.
-        print_rank_0(string)
+        log_dist(string, ranks=ranks or [0])
 class ThroughputTimer():
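Not part of the commit itself, but a minimal usage sketch of what the exposed `ranks` argument enables, assuming torch.distributed has been initialized and a DeepSpeed build that includes this change; `run_forward_pass` is a hypothetical stand-in for real work.

# Minimal sketch (not from this commit) of the new `ranks` argument.
from deepspeed.utils.timer import SynchronizedWallClockTimer

timers = SynchronizedWallClockTimer()

timers('forward').start()
run_forward_pass()  # hypothetical workload, stands in for a real forward pass
timers('forward').stop()

# Default behavior matches the old print_rank_0: the timing line is
# logged on rank 0 only, via log_dist(string, ranks=ranks or [0]).
timers.log(['forward'])

# With model parallelism, rank 0 alone is too restrictive; callers can
# now choose which ranks emit the timing line.
timers.log(['forward'], ranks=[0, 1])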