"graphbolt/src/vscode:/vscode.git/clone" did not exist on "a272efed96f0a42241235922236519ad8d0ba895"
Commit feecd5d9 authored by mohammad, committed by Deepak Narayanan

Add constant num micro-batches calculator

parent 6ea23928
@@ -69,13 +69,13 @@ def parse_args(extra_args_provider=None, defaults={},
         raise Exception('PyTorch with torch.distributed.ring_exchange '
                         'needed to run pipeline MP!')
     # Checks.
-    args.model_parallel_size = args.pipeline_model_parallel_size * \
-                               args.tensor_model_parallel_size
-    assert args.world_size % args.model_parallel_size == 0, 'world size is not'\
+    model_parallel_size = args.pipeline_model_parallel_size * \
+                          args.tensor_model_parallel_size
+    assert args.world_size % model_parallel_size == 0, 'world size is not'\
         ' divisible by tensor parallel size ({}) times pipeline paralle ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
-    args.data_parallel_size = args.world_size // args.model_parallel_size
+    args.data_parallel_size = args.world_size // model_parallel_size
     if args.rank == 0:
         print('using world size: {}, data-parallel-size: {}, '
               'tensor-model-parallel size: {}, '
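To make the divisibility check above concrete, here is a small worked example with hypothetical sizes (not taken from the commit):

# Hypothetical sizes illustrating the check above.
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

model_parallel_size = pipeline_model_parallel_size * tensor_model_parallel_size  # 8
assert world_size % model_parallel_size == 0
data_parallel_size = world_size // model_parallel_size  # 2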
@@ -29,15 +29,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         return None
     args = get_args()
 
-    world_size = mpu.get_data_parallel_world_size()
-
     # Megatron sampler
     batch_sampler = MegatronPretrainingSampler(
         total_samples=len(dataset),
         consumed_samples=consumed_samples,
-        global_batch_size=args.global_batch_size,
-        rank=mpu.get_data_parallel_rank(),
-        world_size=world_size)
+        micro_batch_size=args.micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size())
 
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
@@ -49,13 +47,15 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:
 
-    def __init__(self, total_samples, consumed_samples,
-                 global_batch_size, rank, world_size):
+    def __init__(self, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
-        self.global_batch_size = global_batch_size
-        self.rank = rank
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
+            data_parallel_size
 
         # Sanity checks.
         assert self.total_samples > 0, \
@@ -63,19 +63,11 @@ class MegatronPretrainingSampler:
         assert self.consumed_samples < self.total_samples, \
             'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                         self.total_samples)
-        assert self.global_batch_size > 0, \
-            'Unexpected global batch size: {}'.format(self.global_batch_size)
-        assert world_size > 0,\
-            'non zero world size is expected: {}'.format(world_size)
-        assert self.rank < world_size,\
-            'rank should be smaller than world size: {}, {}'.format(
-                self.rank, world_size)
-        # Batch size per rank.
-        assert self.global_batch_size % world_size == 0,\
-            'global batch size must be divisible by world size: {}, {}'.format(
-                self.global_batch_size, world_size)
-        self.batch_size_per_rank = self.global_batch_size // world_size
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
 
     def __len__(self):
@@ -87,8 +79,8 @@ class MegatronPretrainingSampler:
         # Last batch if not complete will be dropped.
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
-            if len(batch) == self.global_batch_size:
-                start_idx = self.rank * self.batch_size_per_rank
-                end_idx = start_idx + self.batch_size_per_rank
+            if len(batch) == self.micro_batch_times_data_parallel_size:
+                start_idx = self.data_parallel_rank * self.micro_batch_size
+                end_idx = start_idx + self.micro_batch_size
                 yield batch[start_idx:end_idx]
                 batch = []
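For reference, the reworked sampler can be exercised on its own; a minimal sketch with hypothetical sizes, assuming only the constructor and iteration behavior shown in the hunks above:

# MegatronPretrainingSampler is the class defined in the hunks above.
# Hypothetical sizes: each data-parallel rank takes a micro_batch_size slice
# out of every window of micro_batch_size * data_parallel_size sample indices.
sampler = MegatronPretrainingSampler(total_samples=32,
                                     consumed_samples=0,
                                     micro_batch_size=4,
                                     data_parallel_rank=1,
                                     data_parallel_size=2)
for micro_batch in sampler:
    print(micro_batch)  # rank 1 sees [4, 5, 6, 7], [12, 13, 14, 15], ...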
@@ -113,7 +113,7 @@ def _build_num_microbatches_calculator(args):
     # Constant num micro-batches.
     if args.rampup_batch_size is None:
         micro_batch_times_data_parallel = args.micro_batch_size * \
-            arg.data_parallel_size
+            args.data_parallel_size
         assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
             'global batch size ({}) is not divisible by micro batch size ({})' \
             ' times data parallel size ({})'.format(args.global_batch_size,
@@ -126,6 +126,7 @@ def _build_num_microbatches_calculator(args):
                 num_micro_batches), flush=True)
         _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
             num_micro_batches)
+        return
 
     raise Exception('should not be here.')
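In the constant (no ramp-up) case, the calculator above reduces to one integer division; a worked example with hypothetical sizes:

# Hypothetical sizes for the constant num micro-batches case.
global_batch_size = 256
micro_batch_size = 4
data_parallel_size = 8

micro_batch_times_data_parallel = micro_batch_size * data_parallel_size   # 32
assert global_batch_size % micro_batch_times_data_parallel == 0
num_micro_batches = global_batch_size // micro_batch_times_data_parallel  # 8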
@@ -46,7 +46,7 @@ def init_checkpointed_activations_memory_buffer():
     args = get_args()
 
     per_layer = args.micro_batch_size * args.max_position_embeddings * \
-        args.hidden_size // args.model_parallel_size
+        args.hidden_size // args.tensor_model_parallel_size
     assert args.num_layers % args.checkpoint_num_layers == 0, \
         'number of layers is not divisible by checkpoint-num-layers'
     num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
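Plugging hypothetical sizes into the per-layer expression above shows how the buffer scales down with the tensor-model-parallel size:

# Hypothetical sizes for the checkpointed-activations buffer (elements per layer).
micro_batch_size = 4
max_position_embeddings = 1024
hidden_size = 2048
tensor_model_parallel_size = 2

per_layer = micro_batch_size * max_position_embeddings * \
    hidden_size // tensor_model_parallel_size  # 4194304 elements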
@@ -835,8 +835,6 @@ def build_train_valid_test_data_iterators(
     print_rank_0('> building train, validation, and test datasets ...')
 
-    # Rank and global batch size.
-    data_parallel_size = mpu.get_data_parallel_world_size()
 
     # Backward compatibility, assume fixed batch size.
     if args.iteration > 0 and args.consumed_train_samples == 0:
         args.consumed_train_samples = args.iteration * args.global_batch_size