Commit 6a68502d authored by mohammad's avatar mohammad Committed by Deepak Narayanan
Browse files

Minor fixes for batch size rampup

parent de0b70a0
...@@ -120,9 +120,25 @@ def _build_num_microbatches_calculator(args): ...@@ -120,9 +120,25 @@ def _build_num_microbatches_calculator(args):
num_micro_batches), flush=True) num_micro_batches), flush=True)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches( _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
num_micro_batches) num_micro_batches)
return
raise Exception('should not be here.') else:
assert len(args.rampup_batch_size) == 3, 'expected the following ' \
'format: --rampup-batch-size <start batch size> ' \
'<batch size incerement> <ramp-up samples>'
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
print('will use batch size rampup starting from global batch '
'size {} to global batch size {} with batch size increments '
'{} over {} samples.'.format(start_batch_size,
args.global_batch_size,
batch_size_increment,
ramup_samples), flush=True)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = RampupBatchsizeNumMicroBatches(
start_batch_size, batch_size_increment, ramup_samples,
args.global_batch_size, args.micro_batch_size,
args.data_parallel_size)
class NumMicroBatchesCalculator(ABC): class NumMicroBatchesCalculator(ABC):
...@@ -143,10 +159,10 @@ class NumMicroBatchesCalculator(ABC): ...@@ -143,10 +159,10 @@ class NumMicroBatchesCalculator(ABC):
class ConstantNumMicroBatches(NumMicroBatchesCalculator): class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, num_micro_batches=1): def __init__(self, num_micro_batches=1):
super(ConstantNumMicroBatches, self).__init__(
'constant: {}'.format(num_micro_batches))
assert num_micro_batches >= 1 assert num_micro_batches >= 1
self.num_micro_batches = num_micro_batches self.num_micro_batches = num_micro_batches
super(ConstantNumMicroBatches, self).__init__(
'constant: {}'.format(self.num_micro_batches))
def update(self, consumed_samples): def update(self, consumed_samples):
pass pass
...@@ -172,6 +188,10 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): ...@@ -172,6 +188,10 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
data_parallel_size: data parallel size. data_parallel_size: data parallel size.
""" """
super(RampupBatchsizeNumMicroBatches, self).__init__(
'batch size ramup: {}, {}, {}'.format(
start_batch_size, batch_size_increment, ramup_samples))
self.micro_batch_size = micro_batch_size self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \ self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment