"vscode:/vscode.git/clone" did not exist on "3f329a426a09d0bf3f96095301042a5903bc78eb"
Commit e3e5ea89 authored by Deepak Narayanan

Compute tensor chunk size more cleanly, and add assertion for global batch size

parent 27fc4689
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size
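
For context, a minimal sketch of the arithmetic this assertion guards, using illustrative values that are not taken from the commit: the product of the tensor and pipeline parallel sizes must divide the world size, and the quotient becomes the data parallel size.

    # Illustrative values only; any world size that is a multiple of
    # tensor_model_parallel_size * pipeline_model_parallel_size passes.
    world_size = 16
    tensor_model_parallel_size = 2
    pipeline_model_parallel_size = 4

    model_parallel_size = pipeline_model_parallel_size * tensor_model_parallel_size
    assert world_size % model_parallel_size == 0      # 16 % 8 == 0
    data_parallel_size = world_size // model_parallel_size
    print(data_parallel_size)                         # 2
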
@@ -116,6 +116,10 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.virtual_pipeline_model_parallel_size is not None:
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when '\
+            'using interleaved schedule'
 
     # Parameters dtype.
     args.params_dtype = torch.float
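
A hedged sketch of what the added assertion enforces: when the interleaved (virtual pipeline) schedule is enabled, the global batch size must split evenly across pipeline stages. The helper name and the numbers below are illustrative; the real check runs inside parse_args() on the parsed argument namespace.

    # Standalone, illustrative version of the added check.
    def check_global_batch_size(global_batch_size,
                                pipeline_model_parallel_size,
                                virtual_pipeline_model_parallel_size):
        assert global_batch_size > 0
        if virtual_pipeline_model_parallel_size is not None:
            assert global_batch_size % pipeline_model_parallel_size == 0, \
                'global batch size is not divisible by pipeline parallel size when ' \
                'using interleaved schedule'

    check_global_batch_size(512, 8, 2)      # passes: 512 % 8 == 0
    # check_global_batch_size(510, 8, 2)    # would raise AssertionError
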
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import reduce
+import operator
 import torch
 
 from megatron import get_args
@@ -30,9 +32,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     if args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            args.seq_length * args.micro_batch_size * args.hidden_size) // \
-            mpu.get_tensor_model_parallel_world_size()
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
+            mpu.get_tensor_model_parallel_world_size()
     else:
         tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
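
The refactor derives the element count of the scatter/gather chunk from tensor_shape itself instead of spelling out the three factors again. A minimal standalone sketch of the old versus new computation, with illustrative sizes that are not taken from the commit:

    from functools import reduce
    import operator

    # Illustrative sizes; in Megatron these come from args and mpu.
    seq_length, micro_batch_size, hidden_size = 1024, 4, 2048
    tensor_model_parallel_world_size = 8

    tensor_shape = (seq_length, micro_batch_size, hidden_size)

    # Old form: repeat the product of the three dimensions.
    old_chunk = (seq_length * micro_batch_size * hidden_size) // \
        tensor_model_parallel_world_size

    # New form: compute the element count from tensor_shape, so the chunk
    # size stays consistent with however tensor_shape is defined.
    new_chunk = reduce(operator.mul, tensor_shape, 1) // \
        tensor_model_parallel_world_size

    assert old_chunk == new_chunk   # 1024 * 4 * 2048 // 8 == 1048576

Both forms give the same number of elements per tensor-parallel rank; the new one simply avoids duplicating the shape definition, which appears to be the "more cleanly" in the commit message.
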