Commit 77efcccb authored by Mohammad Shoeybi, committed by Jared Casper

Timing levels

parent 83d78670
......@@ -411,6 +411,32 @@ def _add_logging_args(parser):
help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--timing-log-level', type=int,
default=0, choices=range(0,3),
help='Granularity level to measure and report timing. '
' 0: report only iteration time and make sure timing '
' does not introduce extra overhead.'
' 1: report timing for operations that are executed '
' a very limited number of times (basically once) '
' during each iteration (such as gradient all-reduce). '
' 2: report timing for operations that might be '
' executed numerous times during each iteration. '
'Note that setting the level to 1 or 2 might '
'cause an increase in iteration time.')
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
help='If not set, use a barrier with level 1 time '
'measurements. Note that it is up to the user '
'to make sure that calling a barrier with their '
'timers will not result in hangs. This can happen, '
'for example, if the user adds a level 1 timer '
'that is not called by all ranks.',
dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
choices=['max', 'minmax', 'all'],
help='Options for logging timing:'
' max: report the max timing across all ranks'
' minmax: report min and max timings across all ranks'
' all: report timings of all ranks.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
......
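For reference, a minimal self-contained sketch of just the new timing arguments (a hypothetical standalone parser; the real group lives in _add_logging_args above), showing their defaults:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('logging')
group.add_argument('--timing-log-level', type=int,
                   default=0, choices=range(0, 3))
group.add_argument('--no-barrier-with-level-1-timing',
                   action='store_false', dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
                   choices=['max', 'minmax', 'all'])

args = parser.parse_args([])                 # parse with no flags
assert args.timing_log_level == 0            # overhead-free by default
assert args.barrier_with_L1_time is True     # store_false flag left unset
assert args.timing_log_option == 'minmax'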
......@@ -17,7 +17,6 @@
import os
import sys
import time
from functools import reduce
import operator
import torch
......@@ -25,6 +24,7 @@ import torch
from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer
from .microbatches import build_num_microbatches_calculator
from .timers import Timers
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
......@@ -108,7 +108,7 @@ def set_global_variables(args):
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
_set_timers(args)
_set_global_memory_buffer()
if args.exit_signal_handler:
......@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args):
_GLOBAL_ADLR_AUTORESUME = AutoResume
def _set_timers():
def _set_timers(args):
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
_GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
def _set_global_memory_buffer():
"""Initialize global buffer"""
......@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name):
assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If timing is in progress, stop it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# Currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar.
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '-time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
......
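The hunks that follow migrate the optimizer and schedule call sites to the new level-aware interface. A hedged sketch of the pattern, runnable with the Timers/DummyTimer classes from the new megatron/timers.py shown further down (the stand-in all-reduce function and the SimpleNamespace args are illustrative only):

from types import SimpleNamespace

args = SimpleNamespace(barrier_with_L1_time=False)  # flag added above
timers = Timers(log_level=0, log_option='minmax')   # level 0: no overhead

def allreduce_layernorm_grads(args):                # hypothetical stand-in
    pass

# New pattern: the timer declares its level at creation; at config
# level 0 a DummyTimer is returned, so start/stop cost nothing.
timers('layernorm-grads-all-reduce', log_level=1).start(
    barrier=args.barrier_with_L1_time)
allreduce_layernorm_grads(args)
timers('layernorm-grads-all-reduce').stop()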
......@@ -532,17 +532,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"""
# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
timers('embedding-grads-all-reduce').stop()
# Reduce-scatter setup.
timers('backward-params-all-reduce').start()
timers('grads-reduce-scatter', log_level=1).start(
barrier=args.barrier_with_L1_time)
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
......@@ -563,7 +566,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
group = data_parallel_group,
)
timers('backward-params-all-reduce').stop()
timers('grads-reduce-scatter').stop()
def gather_model_params(self, args, timers):
......@@ -575,7 +578,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
can be copied from param.main_grad to param.
"""
timers('backward-params-all-gather').start()
timers('params-all-gather', log_level=1).start(
barrier=args.barrier_with_L1_time)
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
......@@ -602,7 +606,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for param in param_map:
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
timers('params-all-gather').stop()
def _collect_main_grad_data_for_unscaling(self):
......
......@@ -294,21 +294,24 @@ class MegatronOptimizer(ABC):
"""All-reduce all grads, and all-reduce embeddings."""
# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
timers('layernorm-grads-all-reduce').stop()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
timers('grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
for model in self.models:
model.allreduce_gradients()
timers('backward-params-all-reduce').stop()
timers('grads-all-reduce').stop()
# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
timers('embedding-grads-all-reduce').stop()
class MixedPrecisionOptimizer(MegatronOptimizer):
......@@ -416,7 +419,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
def step(self, args, timers):
# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
......@@ -425,7 +429,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
if self.grad_scaler:
# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=args.barrier_with_L1_time)
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()
......@@ -438,25 +443,29 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads.
timers('optimizer-count-zeros').start()
timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()
# Step the optimizer.
timers('optimizer-inner-step').start()
timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step()
timers('optimizer-inner-step').stop()
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()
......@@ -725,7 +734,8 @@ class FP32Optimizer(MegatronOptimizer):
Always return successful since there is no overflow."""
# Copy main_grads to grads.
timers('optimizer-copy-to-main-grad').start()
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
......@@ -739,20 +749,23 @@ class FP32Optimizer(MegatronOptimizer):
timers('optimizer-copy-to-main-grad').stop()
# Clip gradients.
timers('optimizer-clip-main-grad').start()
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
timers('optimizer-count-zeros').start()
timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()
# Update parameters.
timers('optimizer-inner-step').start()
timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step()
timers('optimizer-inner-step').stop()
......
......@@ -163,7 +163,7 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
input_tensor = None
else:
if timers is not None:
timers('forward-recv').start()
timers('forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
......@@ -182,7 +182,7 @@ def recv_backward(tensor_shape=None, timers=None):
output_tensor_grad = None
else:
if timers is not None:
timers('backward-recv').start()
timers('backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
......@@ -199,7 +199,7 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
timers('forward-send', log_level=2).start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
......@@ -215,7 +215,7 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
timers('backward-send').start()
timers('backward-send', log_level=2).start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
......@@ -232,7 +232,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
output_tensor_grad = None
else:
if timers is not None:
timers('forward-send-backward-recv').start()
timers('forward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
......@@ -250,7 +250,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
input_tensor = None
else:
if timers is not None:
timers('backward-send-forward-recv').start()
timers('backward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
......@@ -265,7 +265,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
......@@ -280,7 +280,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
......@@ -297,7 +297,8 @@ def send_forward_backward_recv_forward_backward(
recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
timers('forward-backward-send-forward-backward-recv',
log_level=2).start()
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
......
......@@ -107,6 +107,7 @@ def forward_step(forward_step_func,
model,
input_tensor,
forward_data_store,
timers,
collect_non_loss_data=False):
"""Forward step for passed-in model.
......@@ -115,9 +116,9 @@ def forward_step(forward_step_func,
Returns output tensor."""
args = get_args()
timers = get_timers()
timers('forward-compute').start()
if timers is not None:
timers('forward-compute', log_level=2).start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
......@@ -138,7 +139,8 @@ def forward_step(forward_step_func,
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
timers('forward-compute').stop()
if timers is not None:
timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
......@@ -151,7 +153,8 @@ def forward_step(forward_step_func,
return [output_tensor]
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
def backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
......@@ -165,8 +168,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# connections.
args = get_args()
timers = get_timers()
timers('backward-compute').start()
if timers is not None:
timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
......@@ -207,7 +210,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
timers('backward-compute').stop()
if timers is not None:
timers('backward-compute').stop()
return input_tensor_grad
......@@ -243,18 +247,19 @@ def forward_backward_no_pipelining(forward_step_func,
for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
timers, collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
timers, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
timers, collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
return forward_data_store
......@@ -269,6 +274,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
forward_data_store = []
......@@ -278,7 +286,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
......@@ -337,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
model[model_chunk_id],
input_tensor,
forward_data_store,
timers,
collect_non_loss_data)
output_tensors[model_chunk_id].append(output_tensor)
......@@ -364,7 +372,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
backward_step(optimizer,
input_tensor,
output_tensor,
output_tensor_grad)
output_tensor_grad,
timers)
return input_tensor_grad
......@@ -620,8 +629,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
timers = get_timers()
assert len(model) == 1
model = model[0]
......@@ -656,7 +664,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
timers, collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
......@@ -676,7 +684,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
timers, collect_non_loss_data)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
......@@ -701,7 +709,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
output_tensor_grad, timers)
if last_iteration:
input_tensor = None
......@@ -721,7 +729,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
output_tensor_grad, timers)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron timers."""
from abc import ABC
from abc import abstractmethod
import time
import torch
class TimerBase(ABC):
def __init__(self, name):
self.name = name
@abstractmethod
def start(self, barrier=False):
pass
@abstractmethod
def stop(self, barrier=False):
pass
@abstractmethod
def reset(self):
pass
@abstractmethod
def elapsed(self, reset=True, barrier=False):
pass
class DummyTimer(TimerBase):
def __init__(self):
super().__init__('dummy timer')
def start(self, barrier=False):
return
def stop(self, barrier=False):
return
def reset(self):
return
def elapsed(self, reset=True, barrier=False):
raise Exception('dummy timer should not be used to '
'calculate elapsed time')
class Timer(TimerBase):
"""
Comment on using `barrier`: if this flag is set, all the calling
processes wait until they all reach the timing routine.
It is up to the user to make sure all the ranks in `barrier_group`
call it; otherwise, it will result in a hang.
Comment on `barrier_group`: by default it is set to None which,
in torch distributed land, means the global communicator.
"""
def __init__(self, name):
super().__init__(name)
self._elapsed = 0.0
self._started = False
# Note that None will default to the global process group
self._barrier_group = None
self._start_time = time.time()
def set_barrier_group(self, barrier_group):
self._barrier_group = barrier_group
def start(self, barrier=False):
"""Start the timer."""
assert not self._started, 'timer has already been started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
torch.cuda.synchronize()
self._start_time = time.time()
self._started = True
def stop(self, barrier=False):
"""Stop the timer."""
assert self._started, 'timer is not started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
torch.cuda.synchronize()
self._elapsed += (time.time() - self._start_time)
self._started = False
def reset(self):
"""Reset timer."""
self._elapsed = 0.0
self._started = False
def elapsed(self, reset=True, barrier=False):
"""Calculate the elapsed time."""
_started = self._started
# If timing is in progress, stop it first.
if self._started:
self.stop(barrier=barrier)
# Get the elapsed time.
_elapsed = self._elapsed
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if _started:
self.start(barrier=barrier)
return _elapsed
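A minimal usage sketch of this Timer (assumes a CUDA-enabled environment, since start/stop synchronize the device; with barrier=True, torch.distributed must also be initialized and every rank in the barrier group must make the same call):

import time

t = Timer('demo')
t.start()                 # optional barrier, cuda-synchronize, read clock
time.sleep(0.01)
t.stop()                  # accumulate into _elapsed, clear _started
print('{:.2f} ms'.format(t.elapsed(reset=True) * 1000.0))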
class Timers:
"""Group of timers."""
def __init__(self, log_level, log_option):
self._log_level = log_level
self._log_option = log_option
self._timers = {}
self._log_levels = {}
self._dummy_timer = DummyTimer()
self._max_log_level = 2
def __call__(self, name, log_level=None):
# If the timer has already been created and a log level is
# provided, check that it matches the level the timer was
# created with.
if name in self._timers:
if log_level is not None:
assert log_level == self._log_levels[name], \
'input log level {} does not match already existing '\
'log level {} for {} timer'.format(
log_level, self._log_levels[name], name)
return self._timers[name]
# If the timer does not exist and no log level is provided,
# set it to the max log level, which is 2.
if log_level is None:
log_level = self._max_log_level
assert log_level <= self._max_log_level, \
'log level {} is larger than max supported log level {}'.format(
log_level, self._max_log_level)
# Now if the input log level is larger than the one set for
# the timers class, just ignore it and return a dummy timer.
if log_level > self._log_level:
return self._dummy_timer
# Otherwise, initialize the timer and set the level.
self._timers[name] = Timer(name)
self._log_levels[name] = log_level
return self._timers[name]
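A runnable sketch of the gating logic above (no CUDA or distributed setup is needed as long as the real timer is never started):

timers = Timers(log_level=1, log_option='minmax')

lvl1 = timers('grads-all-reduce', log_level=1)
lvl2 = timers('forward-recv', log_level=2)
assert isinstance(lvl1, Timer)        # level 1 <= configured level 1
assert isinstance(lvl2, DummyTimer)   # level 2 > 1: no-op stand-in
lvl2.start(); lvl2.stop()             # safe, measures nothing

try:
    timers('grads-all-reduce', log_level=2)  # mismatched re-request
except AssertionError as err:
    print(err)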
def _get_elapsed_time_all_ranks(self, names, reset, barrier):
"""
Assumptions:
- All the ranks call this function.
- `names` are identical on all ranks.
If the above assumptions are not met, calling this function
will result in a hang.
Arguments:
- names: list of timer names
- reset: reset the timer after recording the elapsed time
- barrier: if set, do a global barrier before time measurements
"""
# First make sure all the callers are in sync.
if barrier:
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
# Here we could use a gather on the rank that prints the
# timing; however, there is no gather_base support in
# PyTorch yet. It is simpler to deal with a single tensor,
# and since we are only gathering a small amount of data,
# it should be fine to use all-gather instead of gather.
rank_name_to_time = torch.zeros((world_size, len(names)),
dtype=torch.float,
device=torch.cuda.current_device())
for i, name in enumerate(names):
if name in self._timers:
# Here we don't need to pass the barrier flag as all
# the processes are already in sync. This avoids the
# issue of different timers having different barrier
# groups inside their class.
rank_name_to_time[rank, i] = self._timers[name].elapsed(
reset=reset)
# See the note above for why we are not using gather.
torch.distributed._all_gather_base(rank_name_to_time.view(-1),
rank_name_to_time[rank, :].view(-1))
return rank_name_to_time
def _get_global_min_max_time(self, names, reset, barrier, normalizer):
"""Report only min and max times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
barrier)
name_to_min_max_time = {}
for i, name in enumerate(names):
rank_to_time = rank_name_to_time[:, i]
# Filter out ranks that did not report any timing.
rank_to_time = rank_to_time[rank_to_time > 0.0]
# If the timer exists:
if rank_to_time.numel() > 0:
name_to_min_max_time[name] = (
rank_to_time.min().item() / normalizer,
rank_to_time.max().item() / normalizer)
return name_to_min_max_time
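To illustrate the layout and the zero-filtering (a standalone sketch with made-up values; no distributed init needed): row r of rank_name_to_time holds rank r's elapsed times, one column per name, and zeros mark ranks that never created the timer.

import torch

world_size, names = 4, ['forward-compute', 'backward-compute']
rank_name_to_time = torch.zeros((world_size, len(names)))
rank_name_to_time[1, 0] = 0.120    # rank 1 timed 'forward-compute'
rank_name_to_time[3, 0] = 0.150    # rank 3 timed it as well

col = rank_name_to_time[:, 0]
col = col[col > 0.0]               # drop ranks with no timing
print(col.min().item(), col.max().item())   # ~0.12, ~0.15 (seconds)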
def _get_global_min_max_time_string(self, names, reset, barrier,
normalizer, max_only):
name_to_min_max_time = self._get_global_min_max_time(
names, reset, barrier, normalizer)
if not name_to_min_max_time:
return None
output_string = '(min, max) time across ranks (ms):'
for name in name_to_min_max_time:
min_time, max_time = name_to_min_max_time[name]
if max_only:
output_string += '\n {}: {:.2f}'.format(
(name+' ').ljust(48, '.'), max_time)
else:
output_string += '\n {}: ({:.2f}, {:.2f})'.format(
(name+' ').ljust(48, '.'), min_time, max_time)
return output_string
def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
"""Report times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
barrier)
output_string = 'times across ranks (ms):'
no_reported_timing = True
for i, name in enumerate(names):
not_yet_found = True
for rank in range(torch.distributed.get_world_size()):
if rank_name_to_time[rank, i] > 0:
no_reported_timing = False
if not_yet_found:
not_yet_found = False
output_string += '\n {}:'.format(name)
output_string += '\n rank {:2d}: {:.2f}'.format(
rank, rank_name_to_time[rank, i] / normalizer)
if no_reported_timing:
return None
return output_string
def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False):
"""Log a group of timers."""
# Print.
assert normalizer > 0.0
if self._log_option in ['max', 'minmax']:
max_only = False
if self._log_option == 'max':
max_only = True
output_string = self._get_global_min_max_time_string(
names, reset, barrier, normalizer/1000.0, max_only)
elif self._log_option == 'all':
output_string = self._get_all_ranks_time_string(names,
reset, barrier,
normalizer/1000.0)
else:
raise Exception('unknown timing log option {}'.format(
self._log_option))
# If no input rank is provided, log on last rank.
if rank is None:
rank = torch.distributed.get_world_size() - 1
if rank == torch.distributed.get_rank() and output_string is not None:
print(output_string, flush=True)
def write(self, names, writer, iteration, normalizer=1.0,
reset=False, barrier=False):
"""Write timers to a tensorboard writer
Note that we only report maximum time across ranks to tensorboard.
"""
# Currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# pollutes the runs list, so we just add each as a scalar.
assert normalizer > 0.0
name_to_min_max_time = self._get_global_min_max_time(
names, reset, barrier, normalizer)
if writer is not None:
for name in name_to_min_max_time:
_, max_time = name_to_min_max_time[name]
writer.add_scalar(name + '-time', max_time, iteration)
......@@ -119,23 +119,28 @@ def pretrain(train_valid_test_dataset_provider,
timers = get_timers()
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_provider, model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff.
timers('train/valid/test-data-iterators-setup').start()
timers('train/valid/test-data-iterators-setup', log_level=0).start(
barrier=True)
if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
for _ in range(len(model))
]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
train_data_iterator = [data_iterators[0]
for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1]
for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2]
for data_iterators in all_data_iterators]
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
......@@ -145,7 +150,8 @@ def pretrain(train_valid_test_dataset_provider,
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
timers.log(['model-and-optimizer-setup',
'train/valid/test-data-iterators-setup'], barrier=True)
print_rank_0('training ...')
iteration = 0
......@@ -373,13 +379,9 @@ def setup_model_and_optimizer(model_provider_func,
if args.load is not None:
timers = get_timers()
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load-checkpoint').start()
timers('load-checkpoint', log_level=0).start(barrier=True)
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers('load-checkpoint').stop(barrier=True)
timers.log(['load-checkpoint'])
else:
args.iteration = 0
......@@ -412,19 +414,21 @@ def train_step(forward_step_func, data_iterator,
optimizer.zero_grad()
# Forward pass.
timers('forward-backward', log_level=1).start(
barrier=args.barrier_with_L1_time)
forward_backward_func = get_forward_backward_func()
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
optimizer, fwd_bwd_timers, forward_only=False)
timers('forward-backward').stop()
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
......@@ -433,15 +437,13 @@ def train_step(forward_step_func, data_iterator,
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
# Gather params.
if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
......@@ -511,33 +513,32 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = []
def add_to_logging(name):
if name in timers.timers:
timers_to_log.append(name)
add_to_logging('forward-compute')
add_to_logging('forward-recv')
add_to_logging('forward-send')
add_to_logging('forward-backward-send-forward-backward-recv')
add_to_logging('backward-compute')
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-send-backward-recv')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-layernorm-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-reduce-model-grads')
add_to_logging('backward-gather-model-params')
add_to_logging('optimizer-copy-to-main-grad')
add_to_logging('optimizer-unscale-and-check-inf')
add_to_logging('optimizer-clip-main-grad')
add_to_logging('optimizer-count-zeros')
add_to_logging('optimizer-inner-step')
add_to_logging('optimizer-copy-main-to-model-params')
add_to_logging('optimizer')
add_to_logging('batch-generator')
timers_to_log = [
'forward-backward',
'forward-compute',
'backward-compute',
'batch-generator',
'forward-recv',
'forward-send',
'backward-recv',
'backward-send',
'forward-send-forward-recv',
'forward-send-backward-recv',
'backward-send-forward-recv',
'backward-send-backward-recv',
'forward-backward-send-forward-backward-recv',
'layernorm-grads-all-reduce',
'embedding-grads-all-reduce',
'grads-all-reduce',
'grads-reduce-scatter',
'params-all-gather',
'optimizer-copy-to-main-grad',
'optimizer-unscale-and-check-inf',
'optimizer-clip-main-grad',
'optimizer-count-zeros',
'optimizer-inner-step',
'optimizer-copy-main-to-model-params',
'optimizer']
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
......@@ -547,8 +548,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
total_loss_dict[skipped_iters_key]
# Tensorboard values.
if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
is_last_rank():
# Writing timers requires all the ranks to call it.
if args.log_timers_to_tensorboard and \
(iteration % args.tensorboard_log_interval == 0):
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
......@@ -581,9 +586,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
......@@ -603,7 +605,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed()
elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations
if writer:
if args.log_timers_to_tensorboard:
......@@ -653,11 +655,9 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save-checkpoint').start()
timers('save-checkpoint', log_level=0).start(barrier=True)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('save-checkpoint').stop()
timers('save-checkpoint').stop(barrier=True)
timers.log(['save-checkpoint'])
......@@ -681,7 +681,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Iterations.
iteration = args.iteration
timers('interval-time').start()
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
......
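In train_step() above, per-microbatch pipeline timers are only threaded through to the schedules when level-2 timing is enabled; a hedged sketch of that gating (SimpleNamespace stands in for Megatron's args):

from types import SimpleNamespace

args = SimpleNamespace(timing_log_level=1)
timers = Timers(log_level=args.timing_log_level, log_option='minmax')

# The schedules receive a timers object only at level 2; every use
# inside the schedules is guarded by `if timers is not None`.
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
assert fwd_bwd_timers is None      # level 1: skip per-microbatch timers

if fwd_bwd_timers is not None:
    fwd_bwd_timers('forward-compute', log_level=2).start()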
......@@ -104,7 +104,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
timers('batch-generator', log_level=2).start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator)
timers('batch-generator').stop()
......
......@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
......
......@@ -134,7 +134,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
timers('batch-generator', log_level=2).start()
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
......
......@@ -126,7 +126,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers('batch generator').start()
timers('batch generator', log_level=2).start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator)
timers('batch generator').stop()
......
......@@ -77,7 +77,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
timers("batch-generator", log_level=2).start()
(
images,
labels,
......
......@@ -84,7 +84,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
timers("batch-generator", log_level=2).start()
(
images,
labels,
......
......@@ -91,7 +91,7 @@ def forward_step(data_iterator, model):
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
timers("batch-generator", log_level=2).start()
(
images,
masks,
......
......@@ -67,7 +67,7 @@ def _cross_entropy_forward_step(batch, model):
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
timers('batch-generator', log_level=2).start()
try:
batch_ = next(batch)
except BaseException:
......@@ -178,7 +178,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
report_memory_flag = True
# For each remaining epoch
timers('interval-time').start()
timers('interval-time', log_level=0).start(barrier=True)
for epoch in range(start_epoch, args.epochs):
print_rank_0('working on epoch {} ...'.format(epoch + 1))
......@@ -261,7 +261,7 @@ def finetune(train_valid_datasets_provider, model_provider,
'batch size scaling is not supported for finetuning'
# Train and validation data loaders.
timers('train/valid/test dataset/dataloder').start()
timers('train/valid/test dataset/dataloder', log_level=0).start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
......@@ -271,21 +271,21 @@ def finetune(train_valid_datasets_provider, model_provider,
timers('train/valid/test dataset/dataloder').stop()
# Build callback function.
timers('callback function').start()
timers('callback function', log_level=0).start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop()
# Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start()
timers('model and optimizer', log_level=0).start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers('pretrained checkpoint').start()
timers('pretrained checkpoint', log_level=0).start(barrier=True)
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
......@@ -302,7 +302,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Print setup timing.
print_rank_0('done with setups ...')
timers.log(['train/valid/test dataset/dataloder', 'callback function',
'model and optimizer', 'pretrained checkpoint'])
'model and optimizer', 'pretrained checkpoint'], barrier=True)
print_rank_0('training ...')
# Finetune the model.
......
......@@ -63,7 +63,7 @@ def orqa(Dataset):
tokenizer = get_tokenizer()
# Get the batch.
timers('batch generator').start()
timers('batch generator', log_level=2).start()
try:
batch_ = next(batch)
except BaseException:
......
......@@ -68,7 +68,7 @@ def classification():
timers = get_timers()
# Get the batch.
timers("batch generator").start()
timers("batch generator", log_level=2).start()
try:
batch_ = next(batch)
except BaseException:
......
......@@ -136,7 +136,7 @@ def _train(
report_memory_flag = True
# For each remaining epoch
timers("interval-time").start()
timers("interval-time", log_level=0).start(barrier=True)
for epoch in range(start_epoch, args.epochs):
print_rank_0("working on epoch {} ...".format(epoch + 1))
......@@ -218,7 +218,7 @@ def finetune(
timers = get_timers()
# Train and validation data loaders.
timers("train/valid/test dataset/dataloder").start()
timers("train/valid/test dataset/dataloder", log_level=0).start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
......@@ -227,14 +227,14 @@ def finetune(
timers("train/valid/test dataset/dataloder").stop()
# Build callback function.
timers("callback function").start()
timers("callback function", log_level=0).start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers("callback function").stop()
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
timers("model and optimizer", log_level=0).start()
model, optimizer, opt_param_scheduler = \
setup_model_and_optimizer(
model_provider,
......@@ -246,7 +246,7 @@ def finetune(
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers("pretrained checkpoint").start()
timers("pretrained checkpoint", log_level=0).start(barrier=True)
if args.iteration == 0 and args.pretrained_checkpoint is not None:
if args.pretrained_checkpoint_type == 'default':
original_load = args.load
......
......@@ -123,7 +123,7 @@ def segmentation():
timers = get_timers()
# Get the batch.
timers("batch generator").start()
timers("batch generator", log_level=2).start()
import types
if isinstance(batch, types.GeneratorType):
batch_ = next(batch)
......