Commit 77bff386 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Use helper method in megatron/schedules.py as intended

parent f3be8ead
...@@ -31,6 +31,9 @@ def get_forward_backward_func(): ...@@ -31,6 +31,9 @@ def get_forward_backward_func():
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None: if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else: else:
forward_backward_func = forward_backward_pipelining_without_interleaving forward_backward_func = forward_backward_pipelining_without_interleaving
else: else:
......
...@@ -47,9 +47,7 @@ from megatron.utils import check_adlr_autoresume_termination ...@@ -47,9 +47,7 @@ from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining from megatron.schedules import get_forward_backward_func
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory from megatron.utils import report_memory
...@@ -359,16 +357,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -359,16 +357,7 @@ def train_step(forward_step_func, data_iterator,
else: else:
optimizer.zero_grad() optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1: forward_backward_func = get_forward_backward_func()
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
losses_reduced = forward_backward_func( losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model, forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False) optimizer, timers, forward_only=False)
...@@ -722,13 +711,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -722,13 +711,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration, print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters)) args.eval_iters))
if mpu.get_pipeline_model_parallel_world_size() > 1: forward_backward_func = get_forward_backward_func()
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
loss_dicts = forward_backward_func( loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None, forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True) timers=None, forward_only=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment