Commit 77efcccb authored by Mohammad Shoeybi's avatar Mohammad Shoeybi Committed by Jared Casper
Browse files

Timing levels

parent 83d78670
...@@ -411,6 +411,32 @@ def _add_logging_args(parser): ...@@ -411,6 +411,32 @@ def _add_logging_args(parser):
help='If set, calculate and log parameters norm.') help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true', group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.') help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--timing-log-level', type=int,
default=0, choices=range(0,3),
help='Granularity level to measure and report timing. '
' 0: report only iteration time and make sure timing '
' does not introduce extra overhead.'
' 1: report timing for operations that are executed '
' very limited times (basically once) during '
' each iteration (such as gradient all-reduce) '
                            ' 2: report timing for operations that might be '
' executed numerous times during each iteration. '
'Note that setting the level to 1 or 2 might '
'cause increase in iteration time.')
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
help='If not set, use barrier with level 1 time '
'measurements. Note that this is up to the user '
'to make sure calling barrier with their timers '
'will not result in hangs. This can happen if for '
'example the user adds a level 1 timer that is not '
'called by all ranks.',
dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
choices=['max', 'minmax', 'all'],
help='Options for logging timing:'
' max: report the max timing across all ranks'
' minmax: report min and max timings across all ranks'
' all: report timings of all ranks.')
group.add_argument('--tensorboard-log-interval', type=int, default=1, group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.') help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000, group.add_argument('--tensorboard-queue-size', type=int, default=1000,
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import os import os
import sys import sys
import time
from functools import reduce from functools import reduce
import operator import operator
import torch import torch
...@@ -25,6 +24,7 @@ import torch ...@@ -25,6 +24,7 @@ import torch
from megatron import dist_signal_handler from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer from megatron.tokenizer import build_tokenizer
from .microbatches import build_num_microbatches_calculator from .microbatches import build_num_microbatches_calculator
from .timers import Timers
_GLOBAL_ARGS = None _GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
...@@ -108,7 +108,7 @@ def set_global_variables(args): ...@@ -108,7 +108,7 @@ def set_global_variables(args):
_ = _build_tokenizer(args) _ = _build_tokenizer(args)
_set_tensorboard_writer(args) _set_tensorboard_writer(args)
_set_adlr_autoresume(args) _set_adlr_autoresume(args)
_set_timers() _set_timers(args)
_set_global_memory_buffer() _set_global_memory_buffer()
if args.exit_signal_handler: if args.exit_signal_handler:
...@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args): ...@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args):
_GLOBAL_ADLR_AUTORESUME = AutoResume _GLOBAL_ADLR_AUTORESUME = AutoResume
def _set_timers(): def _set_timers(args):
"""Initialize timers.""" """Initialize timers."""
global _GLOBAL_TIMERS global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers() _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
def _set_global_memory_buffer(): def _set_global_memory_buffer():
"""Initialize global buffer""" """Initialize global buffer"""
...@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name): ...@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name):
assert var is None, '{} is already initialized.'.format(name) assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '-time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
class GlobalMemoryBuffer: class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations. """Global buffer to avoid dynamic memory allocations.
......
...@@ -532,17 +532,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -532,17 +532,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
""" """
# All-reduce layer-norm grads (for sequence parallelism). # All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start() timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args) self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop() timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads. # All-reduce embedding grads.
timers('backward-embedding-all-reduce').start() timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args) self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop() timers('embedding-grads-all-reduce').stop()
# Reduce-scatter setup. # Reduce-scatter setup.
timers('backward-params-all-reduce').start() timers('grads-reduce-scatter', log_level=1).start(
barrier=args.barrier_with_L1_time)
data_parallel_rank = mpu.get_data_parallel_rank() data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size() data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group() data_parallel_group = mpu.get_data_parallel_group()
...@@ -563,7 +566,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -563,7 +566,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
group = data_parallel_group, group = data_parallel_group,
) )
timers('backward-params-all-reduce').stop() timers('grads-reduce-scatter').stop()
def gather_model_params(self, args, timers): def gather_model_params(self, args, timers):
...@@ -575,7 +578,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -575,7 +578,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
can be copied from param.main_grad to param. can be copied from param.main_grad to param.
""" """
timers('backward-params-all-gather').start() timers('params-all-gather', log_level=1).start(
barrier=args.barrier_with_L1_time)
data_parallel_rank = mpu.get_data_parallel_rank() data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group() data_parallel_group = mpu.get_data_parallel_group()
...@@ -602,7 +606,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -602,7 +606,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for param in param_map: for param in param_map:
param.detach().copy_(param.main_grad) param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop() timers('params-all-gather').stop()
def _collect_main_grad_data_for_unscaling(self): def _collect_main_grad_data_for_unscaling(self):
......
...@@ -294,21 +294,24 @@ class MegatronOptimizer(ABC): ...@@ -294,21 +294,24 @@ class MegatronOptimizer(ABC):
"""All-reduce all grads, and all-reduce embeddings.""" """All-reduce all grads, and all-reduce embeddings."""
# All-reduce layer-norm grads (for sequence parallelism). # All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start() timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args) self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop() timers('layernorm-grads-all-reduce').stop()
# All-reduce if needed. # All-reduce if needed.
if args.DDP_impl == 'local': if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start() timers('grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
for model in self.models: for model in self.models:
model.allreduce_gradients() model.allreduce_gradients()
timers('backward-params-all-reduce').stop() timers('grads-all-reduce').stop()
# All-reduce embedding grads. # All-reduce embedding grads.
timers('backward-embedding-all-reduce').start() timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args) self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop() timers('embedding-grads-all-reduce').stop()
class MixedPrecisionOptimizer(MegatronOptimizer): class MixedPrecisionOptimizer(MegatronOptimizer):
...@@ -416,7 +419,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer): ...@@ -416,7 +419,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
def step(self, args, timers): def step(self, args, timers):
# Copy gradients from model params to main params. # Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start() timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_model_grads_to_main_grads() self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop() timers('optimizer-copy-to-main-grad').stop()
...@@ -425,7 +429,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer): ...@@ -425,7 +429,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
if self.grad_scaler: if self.grad_scaler:
# Unscale and check for inf/nan. # Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start() timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=args.barrier_with_L1_time)
found_inf_flag = self._unscale_main_grads_and_check_for_nan() found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop() timers('optimizer-unscale-and-check-inf').stop()
...@@ -438,25 +443,29 @@ class MixedPrecisionOptimizer(MegatronOptimizer): ...@@ -438,25 +443,29 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return False, None, None return False, None, None
# Clip the main gradients. # Clip the main gradients.
timers('optimizer-clip-main-grad').start() timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None grad_norm = None
if self.clip_grad > 0.0: if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad) grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop() timers('optimizer-clip-main-grad').stop()
# Count the zeros in the grads. # Count the zeros in the grads.
timers('optimizer-count-zeros').start() timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \ num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop() timers('optimizer-count-zeros').stop()
# Step the optimizer. # Step the optimizer.
timers('optimizer-inner-step').start() timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step() self.optimizer.step()
timers('optimizer-inner-step').stop() timers('optimizer-inner-step').stop()
# Update params from main params. # Update params from main params.
timers('optimizer-copy-main-to-model-params').start() timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_main_params_to_model_params() self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop() timers('optimizer-copy-main-to-model-params').stop()
...@@ -725,7 +734,8 @@ class FP32Optimizer(MegatronOptimizer): ...@@ -725,7 +734,8 @@ class FP32Optimizer(MegatronOptimizer):
Always return successful since there is no overflow.""" Always return successful since there is no overflow."""
# Copy main_grads to grads. # Copy main_grads to grads.
timers('optimizer-copy-to-main-grad').start() timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
if self.params_have_main_grad: if self.params_have_main_grad:
for param_group in self.optimizer.param_groups: for param_group in self.optimizer.param_groups:
for param in param_group['params']: for param in param_group['params']:
...@@ -739,20 +749,23 @@ class FP32Optimizer(MegatronOptimizer): ...@@ -739,20 +749,23 @@ class FP32Optimizer(MegatronOptimizer):
timers('optimizer-copy-to-main-grad').stop() timers('optimizer-copy-to-main-grad').stop()
# Clip gradients. # Clip gradients.
timers('optimizer-clip-main-grad').start() timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None grad_norm = None
if self.clip_grad > 0.0: if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad) grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop() timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads # count the zeros in the grads
timers('optimizer-count-zeros').start() timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \ num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop() timers('optimizer-count-zeros').stop()
# Update parameters. # Update parameters.
timers('optimizer-inner-step').start() timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step() self.optimizer.step()
timers('optimizer-inner-step').stop() timers('optimizer-inner-step').stop()
......
...@@ -163,7 +163,7 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None): ...@@ -163,7 +163,7 @@ def recv_forward(tensor_shape=None, dtype_=None, timers=None):
input_tensor = None input_tensor = None
else: else:
if timers is not None: if timers is not None:
timers('forward-recv').start() timers('forward-recv', log_level=2).start()
input_tensor, _ = _communicate( input_tensor, _ = _communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -182,7 +182,7 @@ def recv_backward(tensor_shape=None, timers=None): ...@@ -182,7 +182,7 @@ def recv_backward(tensor_shape=None, timers=None):
output_tensor_grad = None output_tensor_grad = None
else: else:
if timers is not None: if timers is not None:
timers('backward-recv').start() timers('backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate( _, output_tensor_grad = _communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -199,7 +199,7 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None): ...@@ -199,7 +199,7 @@ def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
if not mpu.is_pipeline_last_stage(): if not mpu.is_pipeline_last_stage():
if timers is not None: if timers is not None:
timers('forward-send').start() timers('forward-send', log_level=2).start()
_communicate( _communicate(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -215,7 +215,7 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None): ...@@ -215,7 +215,7 @@ def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send).""" """Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage(): if not mpu.is_pipeline_first_stage():
if timers is not None: if timers is not None:
timers('backward-send').start() timers('backward-send', log_level=2).start()
_communicate( _communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
...@@ -232,7 +232,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None): ...@@ -232,7 +232,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
output_tensor_grad = None output_tensor_grad = None
else: else:
if timers is not None: if timers is not None:
timers('forward-send-backward-recv').start() timers('forward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate( _, output_tensor_grad = _communicate(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -250,7 +250,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None ...@@ -250,7 +250,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
input_tensor = None input_tensor = None
else: else:
if timers is not None: if timers is not None:
timers('backward-send-forward-recv').start() timers('backward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate( input_tensor, _ = _communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
...@@ -265,7 +265,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None ...@@ -265,7 +265,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None): def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline.""" """Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None: if timers is not None:
timers('forward-send-forward-recv').start() timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _ = _communicate( input_tensor, _ = _communicate(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -280,7 +280,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer ...@@ -280,7 +280,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timer
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None): def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline.""" """Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None: if timers is not None:
timers('backward-send-backward-recv').start() timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad = _communicate( _, output_tensor_grad = _communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
...@@ -297,7 +297,8 @@ def send_forward_backward_recv_forward_backward( ...@@ -297,7 +297,8 @@ def send_forward_backward_recv_forward_backward(
recv_next, tensor_shape=None, timers=None): recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline.""" """Batched send and recv with previous and next ranks in pipeline."""
if timers is not None: if timers is not None:
timers('forward-backward-send-forward-backward-recv').start() timers('forward-backward-send-forward-backward-recv',
log_level=2).start()
input_tensor, output_tensor_grad = _communicate( input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
......
...@@ -107,6 +107,7 @@ def forward_step(forward_step_func, ...@@ -107,6 +107,7 @@ def forward_step(forward_step_func,
model, model,
input_tensor, input_tensor,
forward_data_store, forward_data_store,
timers,
collect_non_loss_data=False): collect_non_loss_data=False):
"""Forward step for passed-in model. """Forward step for passed-in model.
...@@ -115,9 +116,9 @@ def forward_step(forward_step_func, ...@@ -115,9 +116,9 @@ def forward_step(forward_step_func,
Returns output tensor.""" Returns output tensor."""
args = get_args() args = get_args()
timers = get_timers()
timers('forward-compute').start() if timers is not None:
timers('forward-compute', log_level=2).start()
unwrapped_model = unwrap_model( unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module)) model, (torchDDP, LocalDDP, Float16Module))
...@@ -138,7 +139,8 @@ def forward_step(forward_step_func, ...@@ -138,7 +139,8 @@ def forward_step(forward_step_func,
data = loss_func(output_tensor, non_loss_data=True) data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data) forward_data_store.append(data)
timers('forward-compute').stop() if timers is not None:
timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder) # If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state # and in decoder stack, then send encoder_hidden_state
...@@ -151,7 +153,8 @@ def forward_step(forward_step_func, ...@@ -151,7 +153,8 @@ def forward_step(forward_step_func,
return [output_tensor] return [output_tensor]
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): def backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers):
"""Backward step through passed-in output tensor. """Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss If last stage, output_tensor_grad is None, otherwise gradient of loss
...@@ -165,8 +168,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): ...@@ -165,8 +168,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# connections. # connections.
args = get_args() args = get_args()
timers = get_timers() if timers is not None:
timers('backward-compute').start() timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor. # Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False unwrap_input_tensor_grad = False
...@@ -207,7 +210,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): ...@@ -207,7 +210,8 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
if unwrap_input_tensor_grad: if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0] input_tensor_grad = input_tensor_grad[0]
timers('backward-compute').stop() if timers is not None:
timers('backward-compute').stop()
return input_tensor_grad return input_tensor_grad
...@@ -243,18 +247,19 @@ def forward_backward_no_pipelining(forward_step_func, ...@@ -243,18 +247,19 @@ def forward_backward_no_pipelining(forward_step_func,
for i in range(get_num_microbatches() - 1): for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store, model, input_tensor, forward_data_store,
collect_non_loss_data) timers, collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad) timers, output_tensor_grad)
# Run computation for last microbatch out of context handler (want to # Run computation for last microbatch out of context handler (want to
# synchronize gradients). # synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store, model, input_tensor, forward_data_store,
collect_non_loss_data) timers, collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
return forward_data_store return forward_data_store
...@@ -269,6 +274,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -269,6 +274,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
communication between pipeline stages as needed. communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args()
input_tensors = [[] for _ in range(len(model))] input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))] output_tensors = [[] for _ in range(len(model))]
forward_data_store = [] forward_data_store = []
...@@ -278,7 +286,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -278,7 +286,6 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size() pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank() pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
if args.sequence_parallel: if args.sequence_parallel:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else: else:
...@@ -337,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -337,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
model[model_chunk_id], model[model_chunk_id],
input_tensor, input_tensor,
forward_data_store, forward_data_store,
timers,
collect_non_loss_data) collect_non_loss_data)
output_tensors[model_chunk_id].append(output_tensor) output_tensors[model_chunk_id].append(output_tensor)
...@@ -364,7 +372,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...@@ -364,7 +372,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
backward_step(optimizer, backward_step(optimizer,
input_tensor, input_tensor,
output_tensor, output_tensor,
output_tensor_grad) output_tensor_grad,
timers)
return input_tensor_grad return input_tensor_grad
...@@ -620,8 +629,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -620,8 +629,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
args = get_args() args = get_args()
timers = get_timers()
assert len(model) == 1 assert len(model) == 1
model = model[0] model = model[0]
...@@ -656,7 +664,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -656,7 +664,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor = recv_forward(recv_tensor_shapes, timers=timers) input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store, input_tensor, forward_data_store,
collect_non_loss_data) timers, collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only: if not forward_only:
...@@ -676,7 +684,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -676,7 +684,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store, input_tensor, forward_data_store,
collect_non_loss_data) timers, collect_non_loss_data)
if forward_only: if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
...@@ -701,7 +709,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -701,7 +709,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor_grad = \ input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad) output_tensor_grad, timers)
if last_iteration: if last_iteration:
input_tensor = None input_tensor = None
...@@ -721,7 +729,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...@@ -721,7 +729,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
input_tensor_grad = \ input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad) output_tensor_grad, timers)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers) send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron timers."""
from abc import ABC
from abc import abstractmethod
import time
import torch
class TimerBase(ABC):
    """Abstract interface that every timer implementation provides.

    Subclasses implement start/stop/reset/elapsed. The `barrier`
    flag lets callers request a distributed barrier around the
    timing call so all ranks measure the same region together.
    """

    def __init__(self, name):
        # Human-readable identifier used when reporting this timer.
        self.name = name

    @abstractmethod
    def start(self, barrier=False):
        """Begin timing."""
        ...

    @abstractmethod
    def stop(self, barrier=False):
        """End timing and accumulate the interval since start()."""
        ...

    @abstractmethod
    def reset(self):
        """Discard any accumulated time."""
        ...

    @abstractmethod
    def elapsed(self, reset=True, barrier=False):
        """Return accumulated time, optionally resetting afterwards."""
        ...
class DummyTimer(TimerBase):
    """No-op timer handed out when a timer's log level is filtered out.

    start/stop/reset silently do nothing, so call sites need no
    level checks; asking a dummy timer for elapsed time is an error.
    """

    def __init__(self):
        super().__init__('dummy timer')

    def start(self, barrier=False):
        """Do nothing."""
        return None

    def stop(self, barrier=False):
        """Do nothing."""
        return None

    def reset(self):
        """Do nothing."""
        return None

    def elapsed(self, reset=True, barrier=False):
        """Always raise: a dummy timer has no meaningful elapsed time."""
        raise Exception('dummy timer should not be used to '
                        'calculate elapsed time')
class Timer(TimerBase):
    """Wall-clock timer built on `time.time` with optional rank barriers.

    On `barrier`: when the flag is passed, every calling process waits
    until all of them reach the timing routine. It is the caller's job
    to guarantee that every rank in `barrier_group` participates;
    otherwise the program hangs.

    On `barrier_group`: defaults to None, which torch.distributed
    interprets as the global communicator.
    """

    def __init__(self, name):
        super().__init__(name)
        self._elapsed = 0.0    # accumulated seconds over start/stop pairs
        self._started = False  # True while a start/stop interval is open
        # None here makes torch.distributed use the global process group.
        self._barrier_group = None
        self._start_time = time.time()

    def set_barrier_group(self, barrier_group):
        """Select the process group used for barrier synchronization."""
        self._barrier_group = barrier_group

    def start(self, barrier=False):
        """Open a timing interval; fails if one is already open."""
        assert not self._started, 'timer has already been started'
        if barrier:
            torch.distributed.barrier(group=self._barrier_group)
        # Drain in-flight GPU work so the wall clock reflects it.
        torch.cuda.synchronize()
        self._start_time = time.time()
        self._started = True

    def stop(self, barrier=False):
        """Close the open timing interval and accumulate its duration."""
        assert self._started, 'timer is not started'
        if barrier:
            torch.distributed.barrier(group=self._barrier_group)
        torch.cuda.synchronize()
        self._elapsed += time.time() - self._start_time
        self._started = False

    def reset(self):
        """Drop accumulated time and mark the timer as stopped."""
        self._elapsed = 0.0
        self._started = False

    def elapsed(self, reset=True, barrier=False):
        """Return accumulated seconds, optionally resetting afterwards.

        A running timer is stopped for the measurement and restarted
        afterwards, so an in-flight interval is included and the caller
        observes no change in running state.
        """
        was_running = self._started
        if was_running:
            # Close the open interval so it counts toward the total.
            self.stop(barrier=barrier)
        total = self._elapsed
        if reset:
            self.reset()
        if was_running:
            self.start(barrier=barrier)
        return total
class Timers:
    """Group of timers.

    Hands out `Timer` objects by name via `__call__`, filtered by a global
    log level: requests whose level exceeds the configured one receive a
    shared no-op `DummyTimer`, so disabled timers cost nothing. Also
    provides collective reporting of elapsed times across distributed
    ranks, to stdout via `log` or to tensorboard via `write`.
    """

    def __init__(self, log_level, log_option):
        # Only timers requested with log_level <= self._log_level are real.
        self._log_level = log_level
        # Cross-rank aggregation used by `log`: 'max', 'minmax', or 'all'.
        self._log_option = log_option
        # name -> Timer for every real timer created so far.
        self._timers = {}
        # name -> log level the timer was first created with.
        self._log_levels = {}
        # Shared sentinel returned for timers above the active log level.
        self._dummy_timer = DummyTimer()
        # Highest log level a caller is allowed to request.
        self._max_log_level = 2

    def __call__(self, name, log_level=None):
        """Return the timer registered under `name`, creating it on first use.

        Args:
            name: timer identifier.
            log_level: granularity of this timer (0..self._max_log_level).
                Must match the level used when the timer was first created.
                Defaults to the maximum supported level.

        Returns:
            The `Timer` for `name`, or the shared `DummyTimer` when
            `log_level` exceeds the level this group was configured with.
        """
        # If the timer has already been set, then check if the log-level
        # is provided, it matches the one that the timer was created with.
        if name in self._timers:
            if log_level is not None:
                assert log_level == self._log_levels[name], \
                    'input log level {} does not match already existing '\
                    'log level {} for {} timer'.format(
                        log_level, self._log_levels[name], name)
            return self._timers[name]
        # If timer does not exist and no log level is provided,
        # set it to the max log level which is 2.
        if log_level is None:
            log_level = self._max_log_level
        assert log_level <= self._max_log_level, \
            'log level {} is larger than max supported log level {}'.format(
                log_level, self._max_log_level)
        # Now if the input log level is larger than the one set for
        # the timers class, just ignore it and return a dummy timer.
        if log_level > self._log_level:
            return self._dummy_timer
        # Otherwise, initialize the timer and set the level.
        self._timers[name] = Timer(name)
        self._log_levels[name] = log_level
        return self._timers[name]

    def _get_elapsed_time_all_ranks(self, names, reset, barrier):
        """Gather every rank's elapsed times for `names` onto all ranks.

        Assumptions:
            - All the ranks call this function.
            - `names` are identical on all ranks.
        If the above assumptions are not met, calling this function will
        result in hang.

        Arguments:
            - names: list of timer names
            - reset: reset the timer after recording the elapsed time
            - barrier: if set, do a global barrier before time measurements

        Returns:
            A [world_size, len(names)] float tensor on the current CUDA
            device; entry [r, i] is rank r's elapsed seconds for names[i],
            or 0.0 when that rank never created the timer (callers filter
            on > 0.0).
        """
        # First make sure all the callers are in sync.
        if barrier:
            torch.distributed.barrier()

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        # Here we can use gather on the rank we want to print the
        # timing, however, there is no gather_base support in
        # pytorch yet. It is simpler to deal with a single tensor
        # and since we are only gathering a small amount of data,
        # it should be ok to use all-gather instead of gather.
        rank_name_to_time = torch.zeros((world_size, len(names)),
                                        dtype=torch.float,
                                        device=torch.cuda.current_device())
        for i, name in enumerate(names):
            if name in self._timers:
                # Here we don't need to pass the barrier flag as all
                # the processes are already in sync. This avoids the
                # issue of different timers having different barrier
                # groups inside their class.
                rank_name_to_time[rank, i] = self._timers[name].elapsed(
                    reset=reset)

        # See the note above for why we are not using gather.
        # NOTE(review): _all_gather_base is a private torch.distributed API
        # (renamed all_gather_into_tensor in newer releases) — confirm it is
        # available in all supported torch versions.
        torch.distributed._all_gather_base(rank_name_to_time.view(-1),
                                           rank_name_to_time[rank, :].view(-1))

        return rank_name_to_time

    def _get_global_min_max_time(self, names, reset, barrier, normalizer):
        """Report only min and max times across all ranks.

        Returns a dict name -> (min, max), each elapsed time divided by
        `normalizer`; names no rank has timed are omitted.
        """
        rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
                                                             barrier)
        name_to_min_max_time = {}
        for i, name in enumerate(names):
            rank_to_time = rank_name_to_time[:, i]
            # filter out the ones we did not have any timings for
            rank_to_time = rank_to_time[rank_to_time > 0.0]
            # If the timer exists:
            if rank_to_time.numel() > 0:
                name_to_min_max_time[name] = (
                    rank_to_time.min().item() / normalizer,
                    rank_to_time.max().item() / normalizer)
        return name_to_min_max_time

    def _get_global_min_max_time_string(self, names, reset, barrier,
                                        normalizer, max_only):
        """Format min/max (or max-only) times as a printable string.

        Returns None when no rank reported any timing for `names`.
        """
        name_to_min_max_time = self._get_global_min_max_time(
            names, reset, barrier, normalizer)
        if not name_to_min_max_time:
            return None
        output_string = '(min, max) time across ranks (ms):'
        for name in name_to_min_max_time:
            min_time, max_time = name_to_min_max_time[name]
            if max_only:
                output_string += '\n    {}: {:.2f}'.format(
                    (name+' ').ljust(48, '.'), max_time)
            else:
                output_string += '\n    {}: ({:.2f}, {:.2f})'.format(
                    (name+' ').ljust(48, '.'), min_time, max_time)
        return output_string

    def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
        """Report times across all ranks.

        Returns a per-rank breakdown string, or None when no rank reported
        any timing for `names` (zero entries are treated as "not timed").
        """
        rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
                                                             barrier)
        output_string = 'times across ranks (ms):'
        no_reported_timing = True
        for i, name in enumerate(names):
            not_yet_found = True
            for rank in range(torch.distributed.get_world_size()):
                if rank_name_to_time[rank, i] > 0:
                    no_reported_timing = False
                    if not_yet_found:
                        # Print the timer name only once, before its first rank.
                        not_yet_found = False
                        output_string += '\n  {}:'.format(name)
                    output_string += '\n     rank {:2d}: {:.2f}'.format(
                        rank, rank_name_to_time[rank, i] / normalizer)
        if no_reported_timing:
            return None
        return output_string

    def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False):
        """Log a group of timers.

        Collective: every rank must call this, but only `rank` (default:
        the last rank) prints. Elapsed seconds are divided by `normalizer`
        and reported in milliseconds — hence the extra /1000 below.
        """
        # Print.
        assert normalizer > 0.0
        if self._log_option in ['max', 'minmax']:
            max_only = False
            if self._log_option == 'max':
                max_only = True
            output_string = self._get_global_min_max_time_string(
                names, reset, barrier, normalizer/1000.0, max_only)
        elif self._log_option == 'all':
            output_string = self._get_all_ranks_time_string(names,
                                                            reset, barrier,
                                                            normalizer/1000.0)
        else:
            raise Exception('unknown timing log option {}'.format(
                self._log_option))

        # If no input rank is provided, log on last rank.
        if rank is None:
            rank = torch.distributed.get_world_size() - 1
        if rank == torch.distributed.get_rank() and output_string is not None:
            print(output_string, flush=True)

    def write(self, names, writer, iteration, normalizer=1.0,
              reset=False, barrier=False):
        """Write timers to a tensorboard writer.

        Note that we only report maximum time across ranks to tensorboard.
        Collective: all ranks must call this even though only ranks with a
        writer record anything. Unlike `log`, values are not converted to
        ms — they are elapsed seconds divided by `normalizer`.
        """
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        name_to_min_max_time = self._get_global_min_max_time(
            names, reset, barrier, normalizer)
        if writer is not None:
            for name in name_to_min_max_time:
                _, max_time = name_to_min_max_time[name]
                writer.add_scalar(name + '-time', max_time, iteration)
...@@ -119,23 +119,28 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -119,23 +119,28 @@ def pretrain(train_valid_test_dataset_provider,
timers = get_timers() timers = get_timers()
# Model, optimizer, and learning rate. # Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start() timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_type) model_provider, model_type)
timers('model-and-optimizer-setup').stop() timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate ' print_datetime('after model, optimizer, and learning rate '
'scheduler are built') 'scheduler are built')
# Data stuff. # Data stuff.
timers('train/valid/test-data-iterators-setup').start() timers('train/valid/test-data-iterators-setup', log_level=0).start(
barrier=True)
if args.virtual_pipeline_model_parallel_size is not None: if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [ all_data_iterators = [
build_train_valid_test_data_iterators(train_valid_test_dataset_provider) build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
for _ in range(len(model)) for _ in range(len(model))
] ]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators] train_data_iterator = [data_iterators[0]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators] valid_data_iterator = [data_iterators[1]
for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2]
for data_iterators in all_data_iterators]
else: else:
train_data_iterator, valid_data_iterator, test_data_iterator \ train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators( = build_train_valid_test_data_iterators(
...@@ -145,7 +150,8 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -145,7 +150,8 @@ def pretrain(train_valid_test_dataset_provider,
# Print setup timing. # Print setup timing.
print_rank_0('done with setup ...') print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup']) timers.log(['model-and-optimizer-setup',
'train/valid/test-data-iterators-setup'], barrier=True)
print_rank_0('training ...') print_rank_0('training ...')
iteration = 0 iteration = 0
...@@ -373,13 +379,9 @@ def setup_model_and_optimizer(model_provider_func, ...@@ -373,13 +379,9 @@ def setup_model_and_optimizer(model_provider_func,
if args.load is not None: if args.load is not None:
timers = get_timers() timers = get_timers()
# Extra barrier is added to make sure all ranks report the timers('load-checkpoint', log_level=0).start(barrier=True)
# max time.
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler) args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier() timers('load-checkpoint').stop(barrier=True)
timers('load-checkpoint').stop()
timers.log(['load-checkpoint']) timers.log(['load-checkpoint'])
else: else:
args.iteration = 0 args.iteration = 0
...@@ -412,19 +414,21 @@ def train_step(forward_step_func, data_iterator, ...@@ -412,19 +414,21 @@ def train_step(forward_step_func, data_iterator,
optimizer.zero_grad() optimizer.zero_grad()
# Forward pass. # Forward pass.
timers('forward-backward', log_level=1).start(
barrier=args.barrier_with_L1_time)
forward_backward_func = get_forward_backward_func() forward_backward_func = get_forward_backward_func()
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
losses_reduced = forward_backward_func( losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model, forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False) optimizer, fwd_bwd_timers, forward_only=False)
timers('forward-backward').stop()
# Empty unused memory. # Empty unused memory.
if args.empty_unused_memory_level >= 1: if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache() torch.cuda.empty_cache()
# Reduce gradients. # Reduce gradients.
timers('backward-reduce-model-grads').start()
optimizer.reduce_model_grads(args, timers) optimizer.reduce_model_grads(args, timers)
timers('backward-reduce-model-grads').stop()
# Vision gradients. # Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino": if args.vision_pretraining and args.vision_pretraining_type == "dino":
...@@ -433,15 +437,13 @@ def train_step(forward_step_func, data_iterator, ...@@ -433,15 +437,13 @@ def train_step(forward_step_func, data_iterator,
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters. # Update parameters.
timers('optimizer').start() timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers) update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop() timers('optimizer').stop()
# Gather params. # Gather params.
if update_successful: if update_successful:
timers('backward-gather-model-params').start()
optimizer.gather_model_params(args, timers) optimizer.gather_model_params(args, timers)
timers('backward-gather-model-params').stop()
# Vision momentum. # Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino": if args.vision_pretraining and args.vision_pretraining_type == "dino":
...@@ -511,33 +513,32 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -511,33 +513,32 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
nan_iters_key, 0) + int(got_nan) nan_iters_key, 0) + int(got_nan)
# Logging. # Logging.
timers_to_log = [] timers_to_log = [
'forward-backward',
def add_to_logging(name): 'forward-compute',
if name in timers.timers: 'backward-compute',
timers_to_log.append(name) 'batch-generator',
add_to_logging('forward-compute') 'forward-recv',
add_to_logging('forward-recv') 'forward-send',
add_to_logging('forward-send') 'backward-recv',
add_to_logging('forward-backward-send-forward-backward-recv') 'backward-send',
add_to_logging('backward-compute') 'forward-send-forward-recv',
add_to_logging('backward-recv') 'forward-send-backward-recv',
add_to_logging('backward-send') 'backward-send-forward-recv',
add_to_logging('backward-send-forward-recv') 'backward-send-backward-recv',
add_to_logging('backward-send-backward-recv') 'forward-backward-send-forward-backward-recv',
add_to_logging('backward-params-all-reduce') 'layernorm-grads-all-reduce',
add_to_logging('backward-layernorm-all-reduce') 'embedding-grads-all-reduce',
add_to_logging('backward-embedding-all-reduce') 'grads-all-reduce',
add_to_logging('backward-reduce-model-grads') 'grads-reduce-scatter',
add_to_logging('backward-gather-model-params') 'params-all-gather',
add_to_logging('optimizer-copy-to-main-grad') 'optimizer-copy-to-main-grad',
add_to_logging('optimizer-unscale-and-check-inf') 'optimizer-unscale-and-check-inf',
add_to_logging('optimizer-clip-main-grad') 'optimizer-clip-main-grad',
add_to_logging('optimizer-count-zeros') 'optimizer-count-zeros',
add_to_logging('optimizer-inner-step') 'optimizer-inner-step',
add_to_logging('optimizer-copy-main-to-model-params') 'optimizer-copy-main-to-model-params',
add_to_logging('optimizer') 'optimizer']
add_to_logging('batch-generator')
# Calculate batch size. # Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \ batch_size = args.micro_batch_size * args.data_parallel_size * \
...@@ -547,8 +548,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -547,8 +548,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
total_loss_dict[skipped_iters_key] total_loss_dict[skipped_iters_key]
# Tensorboard values. # Tensorboard values.
if writer and (iteration % args.tensorboard_log_interval == 0 ) and \ # Timer requires all the ranks to call.
is_last_rank(): if args.log_timers_to_tensorboard and \
(iteration % args.tensorboard_log_interval == 0):
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if args.log_learning_rate_to_tensorboard: if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration) writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate, writer.add_scalar('learning-rate vs samples', learning_rate,
...@@ -581,9 +586,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -581,9 +586,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
writer.add_scalar('params-norm', params_norm, iteration) writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm, writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples) args.consumed_train_samples)
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if args.log_memory_to_tensorboard: if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats() mem_stats = torch.cuda.memory_stats()
writer.add_scalar( writer.add_scalar(
...@@ -603,7 +605,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -603,7 +605,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
) )
if iteration % args.log_interval == 0: if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed() elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations elapsed_time_per_iteration = elapsed_time / total_iterations
if writer: if writer:
if args.log_timers_to_tensorboard: if args.log_timers_to_tensorboard:
...@@ -653,11 +655,9 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): ...@@ -653,11 +655,9 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers() timers = get_timers()
# Extra barrier is added to make sure # Extra barrier is added to make sure
# all ranks report the max time. # all ranks report the max time.
torch.distributed.barrier() timers('save-checkpoint', log_level=0).start(barrier=True)
timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, opt_param_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier() timers('save-checkpoint').stop(barrier=True)
timers('save-checkpoint').stop()
timers.log(['save-checkpoint']) timers.log(['save-checkpoint'])
...@@ -681,7 +681,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -681,7 +681,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Iterations. # Iterations.
iteration = args.iteration iteration = args.iteration
timers('interval-time').start() timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step') print_datetime('before the start of training step')
report_memory_flag = True report_memory_flag = True
while iteration < args.train_iters: while iteration < args.train_iters:
......
...@@ -104,7 +104,7 @@ def forward_step(data_iterator, model): ...@@ -104,7 +104,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch-generator').start() timers('batch-generator', log_level=2).start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch( tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
data_iterator) data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
......
...@@ -89,7 +89,7 @@ def forward_step(data_iterator, model): ...@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch-generator').start() timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch( tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator) data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
......
...@@ -134,7 +134,7 @@ def forward_step(data_iterator, model): ...@@ -134,7 +134,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch-generator').start() timers('batch-generator', log_level=2).start()
query_tokens, query_mask, \ query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator) context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
......
...@@ -126,7 +126,7 @@ def forward_step(data_iterator, model): ...@@ -126,7 +126,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator', log_level=2).start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \ tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator) = get_batch(data_iterator)
timers('batch generator').stop() timers('batch generator').stop()
......
...@@ -77,7 +77,7 @@ def forward_step(data_iterator, model): ...@@ -77,7 +77,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch-generator").start() timers("batch-generator", log_level=2).start()
( (
images, images,
labels, labels,
......
...@@ -84,7 +84,7 @@ def forward_step(data_iterator, model): ...@@ -84,7 +84,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch-generator").start() timers("batch-generator", log_level=2).start()
( (
images, images,
labels, labels,
......
...@@ -91,7 +91,7 @@ def forward_step(data_iterator, model): ...@@ -91,7 +91,7 @@ def forward_step(data_iterator, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch-generator").start() timers("batch-generator", log_level=2).start()
( (
images, images,
masks, masks,
......
...@@ -67,7 +67,7 @@ def _cross_entropy_forward_step(batch, model): ...@@ -67,7 +67,7 @@ def _cross_entropy_forward_step(batch, model):
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch-generator').start() timers('batch-generator', log_level=2).start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except BaseException: except BaseException:
...@@ -178,7 +178,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step, ...@@ -178,7 +178,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
report_memory_flag = True report_memory_flag = True
# For each remaining epoch # For each remaining epoch
timers('interval-time').start() timers('interval-time', log_level=0).start(barrier=True)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
print_rank_0('working on epoch {} ...'.format(epoch + 1)) print_rank_0('working on epoch {} ...'.format(epoch + 1))
...@@ -261,7 +261,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -261,7 +261,7 @@ def finetune(train_valid_datasets_provider, model_provider,
'batch size scaling is not supported for finetuning' 'batch size scaling is not supported for finetuning'
# Train and validation data loaders. # Train and validation data loaders.
timers('train/valid/test dataset/dataloder').start() timers('train/valid/test dataset/dataloder', log_level=0).start()
if args.epochs > 0: if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider() train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders( train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
...@@ -271,21 +271,21 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -271,21 +271,21 @@ def finetune(train_valid_datasets_provider, model_provider,
timers('train/valid/test dataset/dataloder').stop() timers('train/valid/test dataset/dataloder').stop()
# Build calback function. # Build calback function.
timers('callback function').start() timers('callback function', log_level=0).start()
end_of_epoch_callback = None end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None: if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider() end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop() timers('callback function').stop()
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start() timers('model and optimizer', log_level=0).start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type) model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop() timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for # If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained # any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint. # checkpoint.
timers('pretrained checkpoint').start() timers('pretrained checkpoint', log_level=0).start(barrier=True)
if args.iteration == 0 and args.pretrained_checkpoint is not None: if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load original_load = args.load
args.load = args.pretrained_checkpoint args.load = args.pretrained_checkpoint
...@@ -302,7 +302,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -302,7 +302,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Print setup timing. # Print setup timing.
print_rank_0('done with setups ...') print_rank_0('done with setups ...')
timers.log(['train/valid/test dataset/dataloder', 'callback function', timers.log(['train/valid/test dataset/dataloder', 'callback function',
'model and optimizer', 'pretrained checkpoint']) 'model and optimizer', 'pretrained checkpoint'], barrier=True)
print_rank_0('training ...') print_rank_0('training ...')
# Finetune the model. # Finetune the model.
......
...@@ -63,7 +63,7 @@ def orqa(Dataset): ...@@ -63,7 +63,7 @@ def orqa(Dataset):
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator', log_level=2).start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except BaseException: except BaseException:
......
...@@ -68,7 +68,7 @@ def classification(): ...@@ -68,7 +68,7 @@ def classification():
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch generator").start() timers("batch generator", log_level=2).start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except BaseException: except BaseException:
......
...@@ -136,7 +136,7 @@ def _train( ...@@ -136,7 +136,7 @@ def _train(
report_memory_flag = True report_memory_flag = True
# For each remaining epoch # For each remaining epoch
timers("interval-time").start() timers("interval-time", log_level=0).start(barrier=True)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
print_rank_0("working on epoch {} ...".format(epoch + 1)) print_rank_0("working on epoch {} ...".format(epoch + 1))
...@@ -218,7 +218,7 @@ def finetune( ...@@ -218,7 +218,7 @@ def finetune(
timers = get_timers() timers = get_timers()
# Train and validation data loaders. # Train and validation data loaders.
timers("train/valid/test dataset/dataloder").start() timers("train/valid/test dataset/dataloder", log_level=0).start()
if args.epochs > 0: if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider() train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders( train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
...@@ -227,14 +227,14 @@ def finetune( ...@@ -227,14 +227,14 @@ def finetune(
timers("train/valid/test dataset/dataloder").stop() timers("train/valid/test dataset/dataloder").stop()
# Build calback function. # Build calback function.
timers("callback function").start() timers("callback function", log_level=0).start()
end_of_epoch_callback = None end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None: if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider() end_of_epoch_callback = end_of_epoch_callback_provider()
timers("callback function").stop() timers("callback function").stop()
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start() timers("model and optimizer", log_level=0).start()
model, optimizer, opt_param_scheduler = \ model, optimizer, opt_param_scheduler = \
setup_model_and_optimizer( setup_model_and_optimizer(
model_provider, model_provider,
...@@ -246,7 +246,7 @@ def finetune( ...@@ -246,7 +246,7 @@ def finetune(
# If pretrained checkpoint is provided and we have not trained for # If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained # any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint. # checkpoint.
timers("pretrained checkpoint").start() timers("pretrained checkpoint", log_level=0).start(barrier=True)
if args.iteration == 0 and args.pretrained_checkpoint is not None: if args.iteration == 0 and args.pretrained_checkpoint is not None:
if args.pretrained_checkpoint_type == 'default': if args.pretrained_checkpoint_type == 'default':
original_load = args.load original_load = args.load
......
...@@ -123,7 +123,7 @@ def segmentation(): ...@@ -123,7 +123,7 @@ def segmentation():
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers("batch generator").start() timers("batch generator", log_level=2).start()
import types import types
if isinstance(batch, types.GeneratorType): if isinstance(batch, types.GeneratorType):
batch_ = next(batch) batch_ = next(batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment