Commit 7ce373f3 authored by Deepak Narayanan

Bugfix in megatron/training.py: correct global_batch_size computation

Prevents data_loader from running out of training examples
parent 9d4c735a
@@ -716,7 +716,7 @@ def build_train_valid_test_data_iterators(
     if mpu.get_tensor_model_parallel_rank() == 0:
         # Rank, size, and global batch size.
         data_parallel_size = mpu.get_data_parallel_world_size()
-        global_batch_size = args.batch_size * data_parallel_size
+        global_batch_size = args.batch_size * data_parallel_size * args.num_microbatches_in_minibatch
         # Number of train/valid/test samples.
         train_iters = args.train_iters
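For context, here is a minimal sketch (not the actual Megatron-LM code) of why the missing factor exhausts the data loader: the number of samples requested from the dataset builders is derived from global_batch_size, so omitting the num_microbatches_in_minibatch factor requests only a fraction of the samples that training actually consumes. The helper function and the numeric values below are hypothetical; only the argument names mirror the diff above.

```python
# Hypothetical illustration of the bug fixed in this commit: sizing the dataset
# from an undercounted global_batch_size requests too few training examples.
def samples_requested(batch_size, data_parallel_size,
                      num_microbatches_in_minibatch, train_iters,
                      include_microbatch_factor):
    # Old code: global batch = per-GPU batch * data-parallel replicas.
    global_batch_size = batch_size * data_parallel_size
    if include_microbatch_factor:
        # Fixed code: also multiply by microbatches per training step.
        global_batch_size *= num_microbatches_in_minibatch
    # Total samples the data loader is asked to provide for training.
    return train_iters * global_batch_size

# Example: micro-batch size 8, data-parallel size 4, 16 microbatches per step.
old = samples_requested(8, 4, 16, train_iters=1000, include_microbatch_factor=False)
new = samples_requested(8, 4, 16, train_iters=1000, include_microbatch_factor=True)
print(old, new)  # 32000 vs 512000: the old value is 16x too small, so the
                 # data loader runs out of training examples mid-run.
```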