Commit 23632ee5 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'interleaved_bugfix' into 'main'

Small bugfix to make sure refactored code works with interleaved schedule

See merge request ADLR/megatron-lm!256
parents 3fc035d7 6fd78189
......@@ -195,21 +195,23 @@ def get_model(model_provider_func):
"""Build the model."""
args = get_args()
# Build model on cpu.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment