Commit e3e5ea89 authored by Deepak Narayanan

Compute tensor chunk size more cleanly, and add assertion for global batch size

parent 27fc4689
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size
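
As a quick illustration of the check this hunk touches, here is a minimal standalone sketch with hypothetical values (none of these numbers come from the commit):

# Hypothetical configuration; illustrates the divisibility check above.
world_size = 16                       # total number of ranks (GPUs)
tensor_model_parallel_size = 2        # ranks per tensor-parallel group
pipeline_model_parallel_size = 4      # number of pipeline stages

model_parallel_size = pipeline_model_parallel_size * tensor_model_parallel_size
assert world_size % model_parallel_size == 0    # 16 % 8 == 0, so this passes

data_parallel_size = world_size // model_parallel_size
print(data_parallel_size)    # -> 2 data-parallel replicas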
@@ -116,6 +116,10 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.virtual_pipeline_model_parallel_size is not None:
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when '\
+            'using interleaved schedule'
 
     # Parameters dtype.
     args.params_dtype = torch.float
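To see what the new assertion rules out, here is a minimal standalone sketch with hypothetical values (plain variables standing in for the args fields above):

# Hypothetical values; with the interleaved schedule the batch must split
# evenly across pipeline stages, hence the divisibility requirement.
global_batch_size = 512
pipeline_model_parallel_size = 8
virtual_pipeline_model_parallel_size = 2    # interleaved schedule enabled

if virtual_pipeline_model_parallel_size is not None:
    # 512 % 8 == 0, so this configuration passes; a value such as 510
    # would now fail at argument parsing rather than mid-training.
    assert global_batch_size % pipeline_model_parallel_size == 0, \
        'global batch size is not divisible by pipeline parallel size ' \
        'when using interleaved schedule'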
--- a/megatron/p2p_communication.py
+++ b/megatron/p2p_communication.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import reduce
+import operator
 import torch
 
 from megatron import get_args
@@ -30,9 +32,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     if args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            args.seq_length * args.micro_batch_size * args.hidden_size) // \
-            mpu.get_tensor_model_parallel_world_size()
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
+            mpu.get_tensor_model_parallel_world_size()
     else:
         tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
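The rewritten computation is just the element count of tensor_shape divided across tensor-parallel ranks. A standalone sketch with hypothetical sizes (the plain constant stands in for mpu.get_tensor_model_parallel_world_size()):

from functools import reduce
import operator

# Hypothetical (seq_length, micro_batch_size, hidden_size).
tensor_shape = (1024, 8, 4096)
tensor_model_parallel_world_size = 4    # stand-in for the mpu call

# reduce(operator.mul, tensor_shape, 1) is the total number of elements,
# equivalent to seq_length * micro_batch_size * hidden_size but robust
# to the shape tuple gaining or losing dimensions.
numel = reduce(operator.mul, tensor_shape, 1)
tensor_chunk_shape = numel // tensor_model_parallel_world_size
print(numel, tensor_chunk_shape)    # -> 33554432 8388608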