Commit feecd5d9 authored by mohammad, committed by Deepak Narayanan

Add constant num micro-batches calculator

parent 6ea23928
@@ -69,13 +69,13 @@ def parse_args(extra_args_provider=None, defaults={},
         raise Exception('PyTorch with torch.distributed.ring_exchange '
                         'needed to run pipeline MP!')
     # Checks.
-    args.model_parallel_size = args.pipeline_model_parallel_size * \
-                               args.tensor_model_parallel_size
-    assert args.world_size % args.model_parallel_size == 0, 'world size is not'\
+    model_parallel_size = args.pipeline_model_parallel_size * \
+                          args.tensor_model_parallel_size
+    assert args.world_size % model_parallel_size == 0, 'world size is not'\
         ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
-    args.data_parallel_size = args.world_size // args.model_parallel_size
+    args.data_parallel_size = args.world_size // model_parallel_size
     if args.rank == 0:
         print('using world size: {}, data-parallel-size: {}, '
               'tensor-model-parallel size: {}, '
...
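For reference, the check above decomposes the launch's world size as tensor-parallel size times pipeline-parallel size times data-parallel size, with the data-parallel size as the leftover factor. A minimal sketch with illustrative numbers (none of these values come from the commit):

```python
# Illustrative decomposition of world size; the sizes are made up.
world_size = 32
tensor_model_parallel_size = 4
pipeline_model_parallel_size = 2

# Combined model-parallel size, now a local rather than an args attribute.
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size  # 8
assert world_size % model_parallel_size == 0
data_parallel_size = world_size // model_parallel_size  # 32 // 8 = 4
```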
@@ -29,15 +29,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         return None
     args = get_args()
-    world_size = mpu.get_data_parallel_world_size()
     # Megatron sampler
     batch_sampler = MegatronPretrainingSampler(
         total_samples=len(dataset),
         consumed_samples=consumed_samples,
-        global_batch_size=args.global_batch_size,
-        rank=mpu.get_data_parallel_rank(),
-        world_size=world_size)
+        micro_batch_size=args.micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size())

     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
@@ -49,13 +47,15 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

-    def __init__(self, total_samples, consumed_samples,
-                 global_batch_size, rank, world_size):
+    def __init__(self, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
-        self.global_batch_size = global_batch_size
-        self.rank = rank
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
+            data_parallel_size

         # Sanity checks.
         assert self.total_samples > 0, \
@@ -63,19 +63,11 @@ class MegatronPretrainingSampler:
         assert self.consumed_samples < self.total_samples, \
             'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                         self.total_samples)
-        assert self.global_batch_size > 0, \
-            'Unexpected global batch size: {}'.format(self.global_batch_size)
-        assert world_size > 0,\
-            'non zero world size is expected: {}'.format(world_size)
-        assert self.rank < world_size,\
-            'rank should be smaller than world size: {}, {}'.format(
-                self.rank, world_size)
-        # Batch size per rank.
-        assert self.global_batch_size % world_size == 0,\
-            'global batch size must be divisible by world size: {}, {}'.format(
-                self.global_batch_size, world_size)
-        self.batch_size_per_rank = self.global_batch_size // world_size
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)

     def __len__(self):
@@ -87,8 +79,8 @@ class MegatronPretrainingSampler:
         # Last batch if not complete will be dropped.
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
-            if len(batch) == self.global_batch_size:
-                start_idx = self.rank * self.batch_size_per_rank
-                end_idx = start_idx + self.batch_size_per_rank
+            if len(batch) == self.micro_batch_times_data_parallel_size:
+                start_idx = self.data_parallel_rank * self.micro_batch_size
+                end_idx = start_idx + self.micro_batch_size
                 yield batch[start_idx:end_idx]
                 batch = []
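The rewritten `__iter__` yields one micro-batch per data-parallel rank: it accumulates `micro_batch_size * data_parallel_size` consecutive sample indices and hands each rank its contiguous `micro_batch_size` slice, instead of slicing a full global batch by a per-rank batch size as before. A standalone sketch of that slicing (the sizes are illustrative, not from the commit):

```python
# Sketch of the slicing logic in __iter__ above; sizes are made up.
micro_batch_size = 2
data_parallel_size = 3
block_size = micro_batch_size * data_parallel_size  # 6 indices per step

batch = list(range(block_size))  # pretend these are dataset indices
for data_parallel_rank in range(data_parallel_size):
    start_idx = data_parallel_rank * micro_batch_size
    end_idx = start_idx + micro_batch_size
    print(data_parallel_rank, batch[start_idx:end_idx])
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
```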
@@ -113,7 +113,7 @@ def _build_num_microbatches_calculator(args):
     # Constant num micro-batches.
     if args.rampup_batch_size is None:
         micro_batch_times_data_parallel = args.micro_batch_size * \
-                                          arg.data_parallel_size
+                                          args.data_parallel_size
         assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
             'global batch size ({}) is not divisible by micro batch size ({})' \
             ' times data parallel size ({})'.format(args.global_batch_size,
...
@@ -126,6 +126,7 @@ def _build_num_microbatches_calculator(args):
                               num_micro_batches), flush=True)
         _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
             num_micro_batches)
+        return

     raise Exception('should not be here.')
...
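Only the constructor call for `ConstantNumMicroBatches` appears in these hunks; the number it receives is the global batch size divided by `micro_batch_size * data_parallel_size`, which the preceding assert guarantees is an integer. A minimal sketch of what such a calculator could look like (the `get` accessor is an assumed interface, not shown in the commit):

```python
class ConstantNumMicroBatches:
    """Sketch only: always reports the same number of micro-batches.

    The diff shows just the constructor call; the accessor name
    below is an assumption for illustration."""

    def __init__(self, num_micro_batches):
        self.num_micro_batches = num_micro_batches

    def get(self):
        return self.num_micro_batches

# E.g. global_batch_size=512, micro_batch_size=4, data_parallel_size=8
# gives 512 // (4 * 8) = 16 micro-batches per step.
```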
@@ -46,7 +46,7 @@ def init_checkpointed_activations_memory_buffer():
     args = get_args()
     per_layer = args.micro_batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.model_parallel_size
+                args.hidden_size // args.tensor_model_parallel_size
     assert args.num_layers % args.checkpoint_num_layers == 0, \
         'number of layers is not divisible by checkpoint-num-layers'
     num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
...
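This change follows from the first hunk, where `args.model_parallel_size` stopped being stored on `args`; the checkpointed-activation buffer is now sized by the tensor-model-parallel partition alone. A worked example of the `per_layer` element count with illustrative values (none are from the commit):

```python
# Illustrative values, not from the commit.
micro_batch_size = 4
max_position_embeddings = 1024
hidden_size = 2048
tensor_model_parallel_size = 2

# Activation elements checkpointed per transformer layer on one rank.
per_layer = micro_batch_size * max_position_embeddings * \
    hidden_size // tensor_model_parallel_size
print(per_layer)  # 4 * 1024 * 2048 // 2 = 4,194,304 elements per layer
```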
@@ -835,8 +835,6 @@ def build_train_valid_test_data_iterators(
     print_rank_0('> building train, validation, and test datasets ...')

     # Rank and global batch size.
     data_parallel_size = mpu.get_data_parallel_world_size()
-    # Backward compatibility, assume fixed batch size.
-    if args.iteration > 0 and args.consumed_train_samples == 0:
-        args.consumed_train_samples = args.iteration * args.global_batch_size
...
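The removed fallback reconstructed `consumed_train_samples` as `iteration * global_batch_size`, which is only correct when the global batch size has been constant for the whole run; with the batch-size ramp-up that `args.rampup_batch_size` points toward, that product over-counts. A hypothetical illustration:

```python
# Hypothetical numbers showing why the removed fallback breaks under ramp-up.
# Suppose the first 100 iterations ran at global batch size 256 (ramp-up)
# and the next 100 iterations at the full 512.
actual_consumed = 100 * 256 + 100 * 512  # 76,800 samples
fallback_estimate = 200 * 512            # 102,400 samples (an over-count)
```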