Commit 6ea23928 authored by mohammad and committed by Deepak Narayanan

Add micro-batch size calculator

parent 9019bbf4
......@@ -26,6 +26,8 @@ from .package_info import (
)
from .global_vars import get_args
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
......
......@@ -54,18 +54,45 @@ def parse_args(extra_args_provider=None, defaults={},
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
' ({}) is not divisible by tensor model parallel size ({})'.format(
args.world_size, args.tensor_model_parallel_size)
# Pipeline model parallel size.
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
if args.pipeline_model_parallel_size > 1:
if "ring_exchange" not in dir(torch.distributed):
raise Exception('PyTorch with torch.distributed.ring_exchange needed '
'to run pipeline MP!')
raise Exception('PyTorch with torch.distributed.ring_exchange '
'needed to run pipeline MP!')
# Checks.
args.model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
assert args.world_size % args.model_parallel_size == 0, 'world size ({}) is'\
' not divisible by tensor parallel size ({}) times pipeline parallel ' \
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // args.model_parallel_size
if args.rank == 0:
print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
assert args.global_batch_size > 0
# Fp16 loss scaling.
args.dynamic_loss_scale = False
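A quick worked example of the sizing arithmetic in the hunk above (illustrative numbers only, not taken from the commit):

# Illustrative values: 16 GPUs, tensor parallel 2, pipeline parallel 4.
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size  # 8
assert world_size % model_parallel_size == 0
data_parallel_size = world_size // model_parallel_size  # 2

# Default applied when --global-batch-size is omitted.
micro_batch_size = 4
global_batch_size = micro_batch_size * data_parallel_size  # 8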
......@@ -214,8 +241,22 @@ def _add_training_args(parser):
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size.')
group.add_argument('--num-microbatches', type=int, default=1,
help='Number of microbatches in minibatch')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values: '
'--rampup-batch-size <start batch size> '
'<batch size increment> '
'<ramp-up samples> '
'For example: '
'--rampup-batch-size 16 8 300000 '
'--global-batch-size 1024 '
'will start with global batch size 16 and over '
'(1024 - 16) / 8 = 126 intervals will increase '
'the batch size linearly to 1024. In each interval '
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
......
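The help text above describes a linear ramp-up schedule; the calculator that implements it is not included in this diff (the constant calculator is the only one built below), but the arithmetic the help text describes can be sketched with a hypothetical helper:

# Hypothetical helper, not part of this commit: the linear ramp-up arithmetic
# described by the --rampup-batch-size help text.
def ramped_global_batch_size(consumed_samples, start, increment,
                             rampup_samples, final_batch_size):
    steps = (final_batch_size - start) // increment   # e.g. (1024 - 16) / 8 = 126
    samples_per_step = rampup_samples // steps        # e.g. 300000 / 126 = 2380
    completed = min(consumed_samples // samples_per_step, steps)
    return start + completed * increment

# --rampup-batch-size 16 8 300000 --global-batch-size 1024
assert ramped_global_batch_size(0, 16, 8, 300000, 1024) == 16
assert ramped_global_batch_size(300000, 16, 8, 300000, 1024) == 1024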
......@@ -23,7 +23,7 @@ import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from megatron import mpu, get_args
from megatron import mpu, get_args, update_num_microbatches
from megatron import get_args
from megatron import print_rank_0
......@@ -236,6 +236,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
else:
......
......@@ -30,13 +30,12 @@ def build_pretraining_data_loader(dataset, consumed_samples):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
global_batch_size = args.micro_batch_size * world_size
# Megatron sampler
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
global_batch_size=global_batch_size,
global_batch_size=args.global_batch_size,
rank=mpu.get_data_parallel_rank(),
world_size=world_size)
......
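For context on the sampler change above, here is a minimal sketch of how a pretraining sampler can shard each global batch across data-parallel ranks while skipping already-consumed samples. This is an illustration only, not the actual MegatronPretrainingSampler implementation.

# Minimal sketch, not the actual MegatronPretrainingSampler: shard each global
# batch of sample indices across data-parallel ranks, resuming after
# consumed_samples.
def shard_global_batches(total_samples, consumed_samples, global_batch_size,
                         rank, world_size):
    samples_per_rank = global_batch_size // world_size
    start = rank * samples_per_rank
    for batch_start in range(consumed_samples, total_samples, global_batch_size):
        batch = list(range(batch_start,
                           min(batch_start + global_batch_size, total_samples)))
        yield batch[start:start + samples_per_rank]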
......@@ -15,6 +15,8 @@
"""Megatron global variables."""
from abc import ABC
from abc import abstractmethod
import os
import sys
import time
......@@ -25,18 +27,35 @@ from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
def get_args():
"""Return arguments."""
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
return _GLOBAL_ARGS
def get_num_microbatches_calculator():
"""Return num-microbatches calculator."""
_ensure_var_is_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
'number of micro-batches calculator.')
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR
def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def update_num_microbatches(consumed_samples):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
def get_tokenizer():
"""Return tokenizer."""
_ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
......@@ -67,6 +86,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
......@@ -84,6 +104,62 @@ def _parse_args(extra_args_provider=None, defaults={},
return _GLOBAL_ARGS
def _build_num_microbatches_calculator(args):
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
'num microbatches calculator')
# Constant num micro-batches.
if args.rampup_batch_size is None:
micro_batch_times_data_parallel = args.micro_batch_size * \
args.data_parallel_size
assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
'global batch size ({}) is not divisible by micro batch size ({})' \
' times data parallel size ({})'.format(args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size)
num_micro_batches = args.global_batch_size // \
micro_batch_times_data_parallel
if args.rank == 0:
print('setting number of micro-batches to constant {}'.format(
num_micro_batches), flush=True)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
num_micro_batches)
return
raise Exception('should not be here.')
class NumMicroBatchesCalculator(ABC):
def __init__(self, name):
self.name = name
super(NumMicroBatchesCalculator, self).__init__()
@abstractmethod
def get(self):
pass
def update(self, consumed_samples):
pass
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, num_micro_batches=1):
assert num_micro_batches >= 1
self.num_micro_batches = num_micro_batches
super(ConstantNumMicroBatches, self).__init__(
'constant: {}'.format(self.num_micro_batches))
def update(self, consumed_samples):
pass
def get(self):
return self.num_micro_batches
def _build_tokenizer(args):
"""Initialize tokenizer."""
global _GLOBAL_TOKENIZER
......
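Assuming the ConstantNumMicroBatches class above is in scope, the constant calculator behaves as follows (illustrative numbers):

# Illustrative example: with global batch 1024, micro batch 4 and data parallel
# size 8, the constant calculator returns 1024 // (4 * 8) = 32 micro-batches.
calc = ConstantNumMicroBatches(1024 // (4 * 8))
assert calc.get() == 32
calc.update(consumed_samples=123456)  # no-op for the constant calculator
assert calc.get() == 32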
......@@ -25,6 +25,8 @@ from apex.optimizers import FusedAdam as Adam
from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_num_microbatches
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
......@@ -137,10 +139,6 @@ def get_model(model_provider_func):
if args.fp16:
model = FP16_Module(model)
# Wrap model for distributed training.
if args.num_microbatches > 1:
assert args.DDP_impl == 'local'
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = torchDDP(model, device_ids=[i], output_device=i,
......@@ -225,6 +223,10 @@ def setup_model_and_optimizer(model_provider_func):
else:
args.iteration = 0
# Wrap model for distributed training.
if get_num_microbatches() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
unwrapped_model = model
while hasattr(unwrapped_model, 'module'):
......@@ -315,7 +317,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / args.num_microbatches
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
......@@ -375,7 +377,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / args.num_microbatches
output_tensor = loss / get_num_microbatches()
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
......@@ -419,10 +421,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
args = get_args()
losses_reduced = []
for i in range(args.num_microbatches):
for i in range(get_num_microbatches()):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
output_tensor = loss / args.num_microbatches
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
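The division of each micro-batch loss by get_num_microbatches() above makes gradient accumulation equivalent to averaging over the full global batch; a framework-agnostic illustration:

# Illustrative only: accumulating N losses each scaled by 1/N gives the same
# total as the mean loss over the whole global batch.
N = 4
micro_batch_losses = [2.0, 1.0, 3.0, 2.0]
accumulated = sum(loss / N for loss in micro_batch_losses)
assert abs(accumulated - sum(micro_batch_losses) / N) < 1e-12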
......@@ -441,7 +443,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
args = get_args()
# Compute number of warmup microbatches.
num_microbatches = args.num_microbatches
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
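A worked example of the warm-up count computed above (illustrative): earlier pipeline stages warm up more micro-batches so the pipeline can fill.

# Illustrative only: warm-up micro-batches per pipeline rank for a 4-stage
# pipeline, following the expression above (world_size - rank - 1).
pipeline_world_size = 4
warmup_per_rank = [pipeline_world_size - rank - 1
                   for rank in range(pipeline_world_size)]
assert warmup_per_rank == [3, 2, 1, 0]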
......@@ -695,6 +697,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers('interval time').start()
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
......@@ -703,7 +706,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
args.num_microbatches
get_num_microbatches()
# Logging.
loss_scale = None
......@@ -761,7 +764,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
for _ in range(args.num_microbatches):
for _ in range(get_num_microbatches()):
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
......@@ -789,12 +792,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* args.num_microbatches
* get_num_microbatches()
# Move model back to the train mode.
model.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * args.num_microbatches
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict
......@@ -834,13 +837,12 @@ def build_train_valid_test_data_iterators(
# Rank and global batch size.
data_parallel_size = mpu.get_data_parallel_world_size()
global_batch_size = args.micro_batch_size * data_parallel_size * args.num_microbatches
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
args.consumed_train_samples = args.iteration * global_batch_size
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * global_batch_size
args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
......@@ -849,9 +851,9 @@ def build_train_valid_test_data_iterators(
train_iters = args.train_iters
eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_iters * global_batch_size,
eval_iters * global_batch_size,
test_iters * global_batch_size]
train_val_test_num_samples = [train_iters * args.global_batch_size,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
......
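For the backward-compatibility branch above, the resume arithmetic works out as in this illustrative example:

# Illustrative resume math for the backward-compatibility path above,
# assuming a fixed global batch size throughout training.
iteration, eval_interval, eval_iters, global_batch_size = 5000, 1000, 10, 1024
consumed_train_samples = iteration * global_batch_size                 # 5,120,000
consumed_valid_samples = (iteration // eval_interval) * \
    eval_iters * global_batch_size                                     # 51,200
assert consumed_train_samples == 5_120_000
assert consumed_valid_samples == 51_200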