Commit 6a68502d, authored by mohammad, committed by Deepak Narayanan
Browse files

Minor fixes for batch size rampup

parent de0b70a0
......@@ -120,9 +120,25 @@ def _build_num_microbatches_calculator(args):
num_micro_batches), flush=True)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
num_micro_batches)
return
raise Exception('should not be here.')
else:
assert len(args.rampup_batch_size) == 3, 'expected the following ' \
'format: --rampup-batch-size <start batch size> ' \
'<batch size incerement> <ramp-up samples>'
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
print('will use batch size rampup starting from global batch '
'size {} to global batch size {} with batch size increments '
'{} over {} samples.'.format(start_batch_size,
args.global_batch_size,
batch_size_increment,
ramup_samples), flush=True)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
start_batch_size, batch_size_increment, ramup_samples,
args.global_batch_size, args.micro_batch_size,
args.data_parallel_size)
class NumMicroBatchesCalculator(ABC):
......@@ -143,10 +159,10 @@ class NumMicroBatchesCalculator(ABC):
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, num_micro_batches=1):
super(ConstantNumMicroBatches, self).__init__(
'constant: {}'.format(num_micro_batches))
assert num_micro_batches >= 1
self.num_micro_batches = num_micro_batches
super(ConstantNumMicroBatches, self).__init__(
'constant: {}'.format(self.num_micro_batches))
def update(self, consumed_samples):
pass
......@@ -172,6 +188,10 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
data_parallel_size: data parallel size.
"""
super(RampupBatchsizeNumMicroBatches, self).__init__(
'batch size ramup: {}, {}, {}'.format(
start_batch_size, batch_size_increment, ramup_samples))
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment