Commit 9019bbf4 authored by mohammad, committed by Deepak Narayanan

Rename --batch-size to --micro-batch-size and drop the "in-minibatch" suffix from --num-microbatches-in-minibatch (now --num-microbatches)
parent a84a5fa0
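
For orientation, a minimal sketch (illustrative values only, not code from this commit) of how the renamed settings combine; the formula mirrors the global_batch_size computation in the training hunks below.

# Illustrative sketch: relationship between the renamed arguments.
micro_batch_size = 4       # --micro-batch-size: samples per model instance per forward pass
num_microbatches = 8       # --num-microbatches: microbatches accumulated per iteration
data_parallel_size = 2     # number of data-parallel model replicas

# Samples consumed per training iteration, as computed in
# build_train_valid_test_data_iterators and the consumed_train_samples update below:
global_batch_size = micro_batch_size * data_parallel_size * num_microbatches
assert global_batch_size == 64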
@@ -63,8 +63,6 @@ def parse_args(extra_args_provider=None, defaults={},
if "ring_exchange" not in dir(torch.distributed):
raise Exception('PyTorch with torch.distributed.ring_exchange needed '
'to run pipeline MP!')
- if args.num_microbatches_in_minibatch is None:
- args.num_microbatches_in_minibatch = 1
if args.rank == 0:
print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
@@ -212,11 +210,11 @@ def _add_regularization_args(parser):
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
- group.add_argument('--batch-size', type=int, default=None,
+ group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size.')
- group.add_argument('--num-microbatches-in-minibatch', type=int, default=None,
+ group.add_argument('--num-microbatches', type=int, default=1,
help='Number of microbatches in minibatch')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
......
@@ -30,7 +30,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
- global_batch_size = args.batch_size * world_size
+ global_batch_size = args.micro_batch_size * world_size
# Megatron sampler
batch_sampler = MegatronPretrainingSampler(
......
@@ -9,15 +9,15 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
from megatron import get_args, get_tokenizer, print_rank_0, mpu
- def get_one_epoch_dataloader(dataset, batch_size=None):
+ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
- if batch_size is None:
- batch_size = args.batch_size
- global_batch_size = batch_size * world_size
+ if micro_batch_size is None:
+ micro_batch_size = args.micro_batch_size
+ global_batch_size = micro_batch_size * world_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
......
@@ -80,7 +80,7 @@ __global__ void scaled_masked_softmax_warp_forward(
const input_t *src,
const uint8_t *mask,
const acc_t scale,
- int batch_size,
+ int micro_batch_size,
int stride,
int element_count,
int pad_batches)
@@ -102,9 +102,9 @@ __global__ void scaled_masked_softmax_warp_forward(
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
- // batch_size might not be a multiple of WARP_BATCH. Check how
+ // micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
- int local_batches = batch_size - first_batch;
+ int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
@@ -184,7 +184,7 @@ __global__ void scaled_masked_softmax_warp_backward(
input_t *grad,
const input_t *output,
acc_t scale,
- int batch_size,
+ int micro_batch_size,
int stride,
int element_count)
{
@@ -199,9 +199,9 @@ __global__ void scaled_masked_softmax_warp_backward(
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
- // batch_size might not be a multiple of WARP_BATCH. Check how
+ // micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
- int local_batches = batch_size - first_batch;
+ int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
......
@@ -79,7 +79,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
- int batch_size,
+ int micro_batch_size,
int stride,
int element_count)
{
@@ -94,9 +94,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
- // batch_size might not be a multiple of WARP_BATCH. Check how
+ // micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
- int local_batches = batch_size - first_batch;
+ int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
@@ -173,7 +173,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
input_t *grad,
const input_t *output,
acc_t scale,
- int batch_size,
+ int micro_batch_size,
int stride,
int element_count)
{
@@ -187,9 +187,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
- // batch_size might not be a multiple of WARP_BATCH. Check how
+ // micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
- int local_batches = batch_size - first_batch;
+ int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
......
@@ -45,7 +45,7 @@ def init_checkpointed_activations_memory_buffer():
"""Initializ the memory buffer for the checkpointed activations."""
args = get_args()
- per_layer = args.batch_size * args.max_position_embeddings * \
+ per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers'
......
@@ -138,7 +138,7 @@ def get_model(model_provider_func):
model = FP16_Module(model)
# Wrap model for distributed training."""
- if args.num_microbatches_in_minibatch > 1:
+ if args.num_microbatches > 1:
assert args.DDP_impl == 'local'
if args.DDP_impl == 'torch':
@@ -246,7 +246,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
- tensor_shape = (args.seq_length, args.batch_size, args.hidden_size)
+ tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
@@ -315,7 +315,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
- output_tensor = loss / args.num_microbatches_in_minibatch
+ output_tensor = loss / args.num_microbatches
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
@@ -375,7 +375,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
- output_tensor = loss / args.num_microbatches_in_minibatch
+ output_tensor = loss / args.num_microbatches
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
@@ -419,10 +419,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
args = get_args()
losses_reduced = []
- for i in range(args.num_microbatches_in_minibatch):
+ for i in range(args.num_microbatches):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
- output_tensor = loss / args.num_microbatches_in_minibatch
+ output_tensor = loss / args.num_microbatches
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
@@ -441,15 +441,15 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
args = get_args()
# Compute number of warmup microbatches.
- num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
+ num_microbatches = args.num_microbatches
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
- num_microbatches_in_minibatch)
- num_microbatches_in_minibatch_remaining = \
- num_microbatches_in_minibatch - num_warmup_microbatches
+ num_microbatches)
+ num_microbatches_remaining = \
+ num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
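
The hunk above computes how many warmup microbatches each pipeline stage runs before the steady one-forward-one-backward (1F1B) phase. A minimal standalone sketch of that split, using a hypothetical helper name (not code from this commit):

# Hypothetical helper illustrating the warmup split above.
def split_microbatches(pipeline_world_size, pipeline_rank, num_microbatches):
    # Earlier pipeline stages need more warmup forward passes before 1F1B starts.
    num_warmup_microbatches = min(pipeline_world_size - pipeline_rank - 1,
                                  num_microbatches)
    num_microbatches_remaining = num_microbatches - num_warmup_microbatches
    return num_warmup_microbatches, num_microbatches_remaining

# Example: 4 pipeline stages, 8 microbatches per minibatch.
assert split_microbatches(4, 0, 8) == (3, 5)   # first stage: 3 warmup, 5 in 1F1B
assert split_microbatches(4, 3, 8) == (0, 8)   # last stage: no warmup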
@@ -465,7 +465,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
- if num_microbatches_in_minibatch_remaining > 0:
+ if num_microbatches_remaining > 0:
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
@@ -477,8 +477,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
timers('forward-recv').stop()
# Run 1F1B.
- for i in range(num_microbatches_in_minibatch_remaining):
- last_iteration = (i == (num_microbatches_in_minibatch_remaining - 1))
+ for i in range(num_microbatches_remaining):
+ last_iteration = (i == (num_microbatches_remaining - 1))
input_tensor = \
forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
@@ -702,8 +702,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
lr_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
- args.batch_size * \
- args.num_microbatches_in_minibatch
+ args.micro_batch_size * \
+ args.num_microbatches
# Logging.
loss_scale = None
@@ -761,7 +761,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
- for _ in range(args.num_microbatches_in_minibatch):
+ for _ in range(args.num_microbatches):
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
@@ -788,13 +788,13 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
recv_backward=False)
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
- * args.batch_size \
- * args.num_microbatches_in_minibatch
+ * args.micro_batch_size \
+ * args.num_microbatches
# Move model back to the train mode.
model.train()
for key in total_loss_dict:
- total_loss_dict[key] /= args.eval_iters * args.num_microbatches_in_minibatch
+ total_loss_dict[key] /= args.eval_iters * args.num_microbatches
return total_loss_dict
@@ -834,7 +834,7 @@ def build_train_valid_test_data_iterators(
# Rank and global batch size.
data_parallel_size = mpu.get_data_parallel_world_size()
- global_batch_size = args.batch_size * data_parallel_size * args.num_microbatches_in_minibatch
+ global_batch_size = args.micro_batch_size * data_parallel_size * args.num_microbatches
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
args.consumed_train_samples = args.iteration * global_batch_size
......
@@ -98,11 +98,11 @@ def get_ltor_masks_and_position_ids(data,
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
- batch_size, seq_length = data.size()
+ micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
- att_mask_batch = batch_size
+ att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
@@ -124,7 +124,7 @@
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
- for b in range(batch_size):
+ for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
......
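
The rename above only touches the local variable name; as a reminder of the shapes involved in get_ltor_masks_and_position_ids, a small sketch (the example shapes and the torch.tril call are illustrative, not this commit's code):

# Illustrative shape bookkeeping for the lower-triangular attention mask.
import torch

data = torch.randint(0, 1000, (2, 5))    # [micro_batch_size, seq_length]
micro_batch_size, seq_length = data.size()
att_mask_batch = 1                        # becomes micro_batch_size when reset_attention_mask is set
attention_mask = torch.tril(
    torch.ones(att_mask_batch, seq_length, seq_length))
assert attention_mask.shape == (1, 5, 5)  # one mask shared across the micro-batch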
@@ -87,8 +87,8 @@ def forward_step(data_iterator, model, input_tensor):
# Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
- local_batch_size = query_logits.shape[0]
- global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that tensor_model_parallel_size == 1
+ micro_batch_size = query_logits.shape[0]
+ global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
......
@@ -37,7 +37,7 @@ def accuracy_func_provider(single_dataset_provider):
for datapath in datapaths:
dataset = single_dataset_provider(datapath)
dataloader = build_data_loader(
- dataset, args.batch_size, num_workers=args.num_workers,
+ dataset, args.micro_batch_size, num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1))
dataloaders.append((dataset.dataset_name, dataloader))
......
@@ -71,7 +71,7 @@ def _cross_entropy_forward_step(batch, model):
return loss, {'lm loss': averaged_loss[0]}
- def build_data_loader(dataset, batch_size, num_workers, drop_last):
+ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
@@ -82,7 +82,7 @@ def build_data_loader(dataset, batch_size, num_workers, drop_last):
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(dataset,
- batch_size=batch_size,
+ batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
@@ -109,14 +109,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
- train_dataloader = build_data_loader(train_dataset, args.batch_size,
+ train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
- valid_dataloader_ = build_data_loader(valid_dataset, args.batch_size,
+ valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
......
@@ -186,7 +186,7 @@ def main():
# Data stuff.
dataset = build_dataset(args.task)
- dataloader = build_data_loader(dataset, args.batch_size,
+ dataloader = build_data_loader(dataset, args.micro_batch_size,
args.num_workers, drop_last=False)
# Run evaluation.
......
@@ -86,7 +86,7 @@ def main():
# Generate samples.
if args.num_samples == 0:
- args.batch_size = 1
+ args.micro_batch_size = 1
if args.sample_input_file != None:
generate_samples_input_from_file(model)
else:
......