Commit c30ba0f7 authored by mohammad, committed by Deepak Narayanan

Minor refactoring

parent feecd5d9
@@ -244,7 +244,8 @@ def _add_training_args(parser):
     group.add_argument('--global-batch-size', type=int, default=None,
                        help='Training batch size. If this value is None, then '
                        'use micro-batch-size * data-parallel-size as the '
-                       'global batch size')
+                       'global batch size. This choice will result in 1 for '
+                       'number of micro-batches.')
     group.add_argument('--rampup-batch-size', nargs='*', default=None,
                        help='Batch size ramp up with the following values:'
                        ' --rampup-batch-size <start batch size> '
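The expanded help text spells out a consequence of the default: when `--global-batch-size` is left as None, the global batch size equals micro-batch size times data-parallel size, so each iteration runs exactly one micro-batch. A minimal sketch of that arithmetic, with illustrative values rather than anything from the repository:

```python
micro_batch_size = 4
data_parallel_size = 8

# Default when --global-batch-size is None:
global_batch_size = micro_batch_size * data_parallel_size  # 32

# Micro-batches accumulated per data-parallel rank each step:
num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)
print(num_micro_batches)  # 1, as the new help text notes

# A larger explicit --global-batch-size turns on gradient accumulation:
global_batch_size = 128
num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)
print(num_micro_batches)  # 4
```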
@@ -34,20 +34,12 @@ _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None


 def get_args():
     """Return arguments."""
     _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
     return _GLOBAL_ARGS


-def get_num_microbatches_calculator():
-    """Return num-microbatches calculator."""
-    _ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
-                               'number of micro-batches calculator.')
-    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR
-
-
-
 def get_num_microbatches():
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
@@ -141,6 +133,7 @@ class NumMicroBatchesCalculator(ABC):
     def get(self):
         pass

+    @abstractmethod
     def update(self, consumed_samples):
         pass
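Decorating `update()` with `@abstractmethod` means Python now refuses to instantiate any subclass that forgets to override it, instead of silently inheriting the no-op. A sketch of what that enforces; the base class follows the diff, while the concrete subclass below is illustrative, not the repository's exact code:

```python
from abc import ABC, abstractmethod


class NumMicroBatchesCalculator(ABC):

    @abstractmethod
    def get(self):
        pass

    @abstractmethod
    def update(self, consumed_samples):
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):
    """Fixed micro-batch count; update() is a no-op but must be defined."""

    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
        self.num_micro_batches = global_batch_size // (
            micro_batch_size * data_parallel_size)

    def get(self):
        return self.num_micro_batches

    def update(self, consumed_samples):
        # Nothing to do for a constant schedule, but omitting this method
        # would now raise TypeError at instantiation time.
        pass
```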
@@ -223,7 +223,7 @@ def setup_model_and_optimizer(model_provider_func):
     else:
         args.iteration = 0

-    # Wrap model for distributed training."""
+    # We only support local DDP with multiple micro-batches.
     if get_num_microbatches() > 1:
         assert args.DDP_impl == 'local'
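The reworded comment (which also drops a stray `"""`) makes the constraint explicit: running more than one micro-batch per iteration is only supported with the local DDP wrapper, since gradients are accumulated across micro-batches before reduction. A hedged sketch of that guard in isolation; `check_ddp_impl` is a hypothetical name, not a function in the repository:

```python
def check_ddp_impl(args, num_microbatches):
    # With several micro-batches per iteration, gradients accumulate
    # locally before any reduction, which only the local wrapper supports.
    if num_microbatches > 1:
        assert args.DDP_impl == 'local', \
            'only local DDP supports multiple micro-batches'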