Commit a58a2da6 authored by dongcl

If forward_only is True, recv_backward should not be called.

parent 2b81ee55
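The change is easiest to see in isolation: with `forward_only=True` (an evaluation/validation pass) no backward step will ever run, so the schedule must not post a `recv_backward` for gradients from the next stage and does not need to cache input/output tensors for `backward_step`. Below is a minimal sketch of that guard, using stand-in stubs rather than the real schedule functions; everything except the `forward_only` flag is illustrative.

```python
# Minimal sketch of the guard this commit introduces; the real schedule is
# forward_backward_pipelining_with_cutinhalf, these stubs only illustrate the idea.
def run_microbatch(forward_only: bool):
    def forward_step_stub():        # stand-in for the real forward step
        return "activation"

    def recv_backward_stub():       # stand-in for recv_backward (P2P recv of grads)
        return "grad_from_next_stage"

    def backward_step_stub(grad):   # stand-in for backward_step
        return f"computed with {grad}"

    output = forward_step_stub()

    if not forward_only:
        # Only a training pass needs the gradient from the next pipeline stage
        # and the cached activations for the backward step.
        output_tensor_grad = recv_backward_stub()
        backward_step_stub(output_tensor_grad)
    # With forward_only=True we return after the forward pass: no recv_backward,
    # no input/output tensor bookkeeping.
    return output


run_microbatch(forward_only=True)   # evaluation: skips the backward-side communication
run_microbatch(forward_only=False)  # training: receives grads and runs backward
```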
@@ -19,7 +19,12 @@ from megatron.core.utils import (
 from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
 from megatron.core import ModelParallelConfig
 from megatron.core.pipeline_parallel.p2p_communication import _communicate
-from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, finish_embedding_wgrad_compute
+from megatron.core.pipeline_parallel.schedules import (
+    backward_step,
+    set_current_microbatch,
+    check_first_val_step,
+    finish_embedding_wgrad_compute
+)
 # from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
@@ -114,7 +119,7 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelPa
     return reqs


-def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False, step=-1) -> torch.Tensor:
+def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> torch.Tensor:
     """ Receive tensor from previous rank in pipeline (forward receive).

     See _communicate for argument details.
@@ -565,9 +570,10 @@ def forward_backward_pipelining_with_cutinhalf(
     tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()

     total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

-    input_tensors = [[], []]
-    output_tensors = [[], []]
     forward_data_store = []
+    if not forward_only:
+        input_tensors = [[], []]
+        output_tensors = [[], []]
     master_chunk_id = 0
     slave_chunk_id = 1
@@ -728,17 +734,16 @@ def forward_backward_pipelining_with_cutinhalf(
         slave_cur_microbatch += 1

-        if i == schedule['interleaved_forward'][rank] - 1:
-            firstFB_no_overlp = False
-            firstFB_no_overlp_handle = None
-            # last rank not overlap first F&B
-            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-                firstFB_no_overlp = True
-                output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
-                    tensor_shape, config, slave_chunk_id, async_op=True)
-            else:
-                output_tensor_grad_bwd, _ = recv_backward(
-                    tensor_shape, config, slave_chunk_id)
+        if not forward_only:
+            if i == schedule['interleaved_forward'][rank] - 1:
+                firstFB_no_overlp_handle = None
+                # last rank not overlap first F&B
+                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+                    output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
+                        tensor_shape, config, slave_chunk_id, async_op=True)
+                else:
+                    output_tensor_grad_bwd, _ = recv_backward(
+                        tensor_shape, config, slave_chunk_id)

         fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
                                                     tensor_shape, config, slave_chunk_id, async_op=True)
@@ -838,11 +843,14 @@ def forward_backward_pipelining_with_cutinhalf(
         fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
                                                     tensor_shape, config, slave_chunk_id, async_op=True)

-    fwd_wait_handles_recv = None
     # Run overlaping f&bw stages
+    fwd_wait_handles_recv = None
     fwd_model_chunk_id = master_chunk_id
     bwd_model_chunk_id = slave_chunk_id
-    for step_id in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
+    num_overlap_steps = schedule['overlap'][rank] + schedule['1b1overlap'][rank]
+    if not forward_only:
+        num_overlap_steps += schedule['interleaved_backward'][rank]
+    for step_id in range(num_overlap_steps):
         only_bwd = False
         if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
             only_bwd = True
@@ -853,6 +861,11 @@ def forward_backward_pipelining_with_cutinhalf(
             fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
             set_dualpipe_chunk(fwd_model_chunk_id)
+            if fwd_wait_handles_recv is not None:
+                for req, req_handle in fwd_wait_handles_recv.items():
+                    req_handle.wait()
+                fwd_wait_handles_recv = None
+
             output_tensor, num_tokens = forward_step_no_model_graph(
                 forward_step_func,
                 fwd_model_chunk_id,
@@ -906,7 +919,7 @@ def forward_backward_pipelining_with_cutinhalf(
                 input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
                     output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)

-            if firstFB_no_overlp_handle is not None:
+            if not forward_only and firstFB_no_overlp_handle is not None:
                 for req, req_handle in firstFB_no_overlp_handle.items():
                     if req_handle is not None:
                         req_handle.wait()
@@ -948,8 +961,8 @@ def forward_backward_pipelining_with_cutinhalf(
         # only run backward
         else:
             if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
-                input_tensor, _ = recv_forward(
-                    tensor_shape, config, slave_chunk_id)
+                input_tensor, fwd_wait_handles_recv = recv_forward(
+                    tensor_shape, config, slave_chunk_id, async_op=True)
             if not forward_only:
                 if bwd_wait_handles is not None:
                     for req, req_handle in bwd_wait_handles.items():
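A second part of the change switches the forward receive in the backward-only branch to `async_op=True` and defers the wait on the returned handles (`fwd_wait_handles_recv`) until just before the next forward step, so the transfer overlaps with the intervening backward work. The following is a generic sketch of this issue-early/wait-late pattern with plain `torch.distributed` point-to-point calls; the two-rank setup, gloo backend, and tensor shape are assumptions of the sketch, and the real code receives a dict of handles rather than a single work object.

```python
# Issue-early / wait-late receive, the pattern used with fwd_wait_handles_recv above.
# Assumes two ranks launched with torchrun; shapes and roles are illustrative only.
import torch
import torch.distributed as dist


def overlapped_recv_example(tensor_shape=(4, 8)):
    rank = dist.get_rank()
    if rank == 1:
        # Peer rank: produce the "activation" and send it.
        dist.send(torch.ones(tensor_shape), dst=0)
        return None
    # Rank 0: start the receive asynchronously, like recv_forward(..., async_op=True).
    buf = torch.empty(tensor_shape)
    handle = dist.irecv(buf, src=1)

    # ... unrelated work (e.g. a backward step) runs here while the transfer is in flight ...

    # Block only right before the received tensor is actually consumed.
    handle.wait()
    return buf


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # gloo keeps the sketch CPU-only
    out = overlapped_recv_example()
    dist.destroy_process_group()
```

Run it with e.g. `torchrun --nproc_per_node=2 sketch.py`; the only point carried over from the diff is that `wait()` moves from the receive call site to the consumption site.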
@@ -20,12 +20,16 @@ from megatron.core.num_microbatches_calculator import (
     get_current_global_batch_size,
     get_current_running_global_batch_size,
     get_num_microbatches,
-    update_num_microbatches)
+    update_num_microbatches,
+)
 from megatron.training.async_utils import maybe_finalize_async_save
 from megatron.training.utils import (
     calc_params_l2_norm,
     print_rank_0,
+    logical_and_across_model_parallel_group,
+    reduce_max_stat_across_model_parallel_group,
+    unwrap_model,
 )
 from megatron.training.global_vars import (
     get_args,
@@ -41,7 +45,6 @@ from megatron.training.training import (
     print_datetime,
     should_disable_forward_pre_hook,
     disable_forward_pre_hook,
-    train_step,
     save_checkpoint_and_time,
     enable_forward_pre_hook,
     num_floating_point_operations,
@@ -49,7 +52,12 @@ from megatron.training.training import (
     evaluate_and_print_results,
     post_training_step_callbacks,
     checkpoint_and_decide_exit,
+    cuda_graph_capture,
+    cuda_graph_set_manual_hooks,
+    dummy_train_step,
 )
+from megatron.core.pipeline_parallel import get_forward_backward_func

 stimer = StragglerDetector()
@@ -77,6 +85,122 @@ def build_train_valid_test_data_iterators_wrapper(build_train_valid_test_data_it
     return wrapper


+def train_step(forward_step_func, data_iterator,
+               model, optimizer, opt_param_scheduler, config):
+    """Single training step."""
+    args = get_args()
+    timers = get_timers()
+
+    # CUDA Graph capturing only executes once, when it's the first training iteration.
+    if args.curr_iteration == args.iteration and args.external_cuda_graph:
+        cuda_graph_capture(model, config, args)
+
+    # Set grad to zero.
+    for model_chunk in model:
+        model_chunk.zero_grad_buffer()
+    optimizer.zero_grad()
+
+    # Collect garbage and empty unused memory.
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    rerun_state_machine = get_rerun_state_machine()
+    while rerun_state_machine.should_run_forward_backward(data_iterator):
+        # Set grad to zero.
+        for model_chunk in model:
+            model_chunk.zero_grad_buffer()
+        optimizer.zero_grad()
+
+        # Forward pass.
+        forward_backward_func = get_forward_backward_func()
+        losses_reduced = forward_backward_func(
+            forward_step_func=forward_step_func,
+            data_iterator=data_iterator,
+            model=model,
+            num_microbatches=get_num_microbatches(),
+            seq_length=args.seq_length,
+            micro_batch_size=args.micro_batch_size,
+            decoder_seq_length=args.decoder_seq_length,
+            forward_only=False)
+    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
+    if should_exit:
+        return {}, True, should_checkpoint, should_exit, exit_code, None, None
+
+    # Empty unused memory.
+    if args.empty_unused_memory_level >= 1:
+        torch.cuda.empty_cache()
+
+    # Vision gradients.
+    if args.vision_pretraining and args.vision_pretraining_type == "dino":
+        unwrapped_model = unwrap_model(model[0])
+        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
+
+    # Update parameters.
+    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
+    timers('optimizer').stop()
+
+    # when freezing sub-models we may have a mixture of successful and unsucessful ranks,
+    # so we must gather across mp ranks
+    update_successful = logical_and_across_model_parallel_group(update_successful)
+
+    # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
+    # so we must gather across mp ranks
+    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
+    if args.log_num_zeros_in_grad:
+        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
+
+    # Vision momentum.
+    if args.vision_pretraining and args.vision_pretraining_type == "dino":
+        unwrapped_model = unwrap_model(model[0])
+        unwrapped_model.update_momentum(args.curr_iteration)
+
+    # Update learning rate.
+    if update_successful:
+        increment = get_num_microbatches() * \
+                    args.micro_batch_size * \
+                    args.data_parallel_size
+        opt_param_scheduler.step(increment=increment)
+        skipped_iter = 0
+    else:
+        skipped_iter = 1
+
+    # Empty unused memory.
+    if args.empty_unused_memory_level >= 2:
+        torch.cuda.empty_cache()
+
+    # Set the manual hooks when CUDA Graphs are enabled.
+    if args.curr_iteration == args.iteration and args.external_cuda_graph:
+        if args.use_distributed_optimizer and args.overlap_param_gather:
+            cuda_graph_set_manual_hooks(model)
+
+    if args.schedule_method == 'dualpipev':
+        is_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
+    else:
+        is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
+    if is_last_stage:
+        # Average loss across microbatches.
+        loss_reduced = {}
+        for key in losses_reduced[0].keys():
+            numerator = 0
+            denominator = 0
+            for x in losses_reduced:
+                val = x[key]
+                # there is one dict per microbatch. in new reporting, we average
+                # over the total number of tokens across the global batch.
+                if isinstance(val, tuple) or isinstance(val, list):
+                    numerator += val[0]
+                    denominator += val[1]
+                else:
+                    # legacy behavior. we average over the number of microbatches,
+                    # and so the denominator is 1.
+                    numerator += val
+                    denominator += 1
+            loss_reduced[key] = numerator / denominator
+        return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
+    return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
+
+
 def train(forward_step_func, model, optimizer, opt_param_scheduler,
           train_data_iterator, valid_data_iterator,
           process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
@@ -4,6 +4,15 @@ from megatron.training import get_args
 from megatron.core import mpu

+
+def print_rank_message(message, rank_id=0):
+    """If distributed is initialized, print only on rank specified by rank_id."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == rank_id:
+            print(f"[rank {rank_id}] {message}", flush=True)
+    else:
+        print(f"[rank {rank_id}] {message}", flush=True)
+

 def get_batch_on_this_tp_rank(data_iterator):

     args = get_args()
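For reference, the new helper reads like `print_rank_0` but with a selectable emitting rank; a hypothetical call site (not part of the commit) would look like this, relying on the definition added above:

```python
# Hypothetical usage of the print_rank_message helper added above (not in the commit).
print_rank_message("dualpipev warm-up finished")            # default: only rank 0 prints
print_rank_message("waiting on recv handles", rank_id=3)    # only rank 3 prints
```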