Commit 13b3dca6 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

make interleaving work with optimizations

parent eec218d8
......@@ -279,8 +279,12 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.model_parallel_memory_opt:
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment