Commit e3e5ea89 authored by Deepak Narayanan

Compute tensor chunk size more cleanly, and add assertion for global batch size

parent 27fc4689
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,6 +116,10 @@ def parse_args(extra_args_provider=None, defaults={},
         print('setting global batch size to {}'.format(
             args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.virtual_pipeline_model_parallel_size is not None:
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when '\
+            'using interleaved schedule'
 
     # Parameters dtype.
     args.params_dtype = torch.float
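The assertion added above applies only when a virtual pipeline model parallel size is set, i.e. when the interleaved pipeline schedule is in use, and it requires the global batch size to split evenly across the pipeline parallel ranks. A minimal standalone sketch of the same check, with a hypothetical helper name and made-up sizes rather than anything from the repository:

    def check_global_batch_size(global_batch_size,
                                pipeline_model_parallel_size,
                                virtual_pipeline_model_parallel_size):
        # The constraint only applies to the interleaved (virtual pipeline) schedule.
        if virtual_pipeline_model_parallel_size is not None:
            assert global_batch_size % pipeline_model_parallel_size == 0, \
                'global batch size is not divisible by pipeline parallel size ' \
                'when using interleaved schedule'

    check_global_batch_size(512, 8, 2)      # passes: 512 % 8 == 0
    # check_global_batch_size(510, 8, 2)    # would raise AssertionError: 510 % 8 == 6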
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import reduce
+import operator
 import torch
 
 from megatron import get_args
@@ -30,9 +32,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     if args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            args.seq_length * args.micro_batch_size * args.hidden_size) // \
-            mpu.get_tensor_model_parallel_world_size()
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
+            mpu.get_tensor_model_parallel_world_size()
     else:
         tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
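The _communicate change above derives the chunk size for the scatter/gather path from the tensor_shape tuple itself, via functools.reduce with operator.mul, instead of re-multiplying the individual seq_length, micro_batch_size, and hidden_size values. A self-contained illustration with placeholder numbers (the sizes and the world size below are made up, not taken from the repository):

    from functools import reduce
    import operator

    # Stand-ins for args.seq_length, args.micro_batch_size, args.hidden_size
    # and mpu.get_tensor_model_parallel_world_size().
    tensor_shape = (1024, 4, 2048)
    tensor_model_parallel_world_size = 8

    # Product of all dimensions: the total element count of the tensor.
    numel = reduce(operator.mul, tensor_shape, 1)
    assert numel == 1024 * 4 * 2048

    # Each tensor-parallel rank exchanges an equal-sized flat chunk.
    tensor_chunk_shape = numel // tensor_model_parallel_world_size
    print(tensor_chunk_shape)  # 1048576

On Python 3.8+, math.prod(tensor_shape) computes the same product; reduce(operator.mul, ...) is the equivalent form used in the change.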