Commit ff5427cf authored by dongcl

add dualpipev

parent 6dcd0fb8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import contextlib
from functools import wraps
from typing import Iterator, List, Union
import torch
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.training import get_args
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
get_attr_wrapped_model,
get_model_config,
get_model_type,
)
from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel.p2p_communication import _communicate
from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
from megatron.core.models.gpt import GPTModel
from mindspeed.core.pipeline_parallel.fb_overlap.gpt_model import gpt_model_backward
from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import P2PCommParams
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
# Types
Shape = Union[List[int], torch.Size]
LOSS_BACKWARD_SCALE = torch.tensor(1.0)
_DUALPIPE_CHUNK = None
def set_dualpipe_chunk(chunkid):
"""set_dualpipe_chunk for fp16forward patch"""
global _DUALPIPE_CHUNK
_DUALPIPE_CHUNK = chunkid
def get_dualpipe_chunk():
global _DUALPIPE_CHUNK
if _DUALPIPE_CHUNK is not None:
return _DUALPIPE_CHUNK
else:
raise AssertionError("_DUALPIPE_CHUNK is None")
def is_dualpipev_last_stgae(model_chunk_id):
return parallel_state.is_pipeline_first_stage(ignore_virtual=True) and model_chunk_id == 1
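# DualPipeV convention assumed throughout this file: each pipeline rank holds two
# model chunks. Chunk 0 flows from the first towards the last pipeline stage, and
# chunk 1 flows back from the last towards the first stage, so the logical last
# stage of the V-shaped schedule is chunk 1 running on pipeline rank 0 (see
# is_dualpipev_last_stgae above). The send/recv helpers below pick their peer
# (next vs. previous rank) from the model chunk id accordingly.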
def send_forward(output_tensor: torch.Tensor, tensor_shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage():
return None
tensor_send_next = output_tensor
else:
if parallel_state.is_pipeline_first_stage():
return None
tensor_send_prev = output_tensor
if config.timers is not None:
config.timers('forward-send', log_level=2).start()
_, _, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op)
)
if config.timers is not None:
config.timers('forward-send').stop()
return fwd_wait_handles
def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_first_stage():
return None
tensor_send_prev = input_tensor_grad
else:
if parallel_state.is_pipeline_last_stage():
return None
tensor_send_next = input_tensor_grad
if config.timers is not None:
config.timers('backward-send', log_level=2).start()
_, _, reqs = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op)
)
if config.timers is not None:
config.timers('backward-send').stop()
return reqs
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
if model_chunk_id == 0:
recv_prev = True
else:
recv_next = True
if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
fwd_wait_handles = None
return None, fwd_wait_handles
else:
if config.timers is not None:
config.timers('forward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op),
)
if config.timers is not None:
config.timers('forward-recv').stop()
if recv_prev:
return tensor_recv_prev, fwd_wait_handles
else:
return tensor_recv_next, fwd_wait_handles
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
if model_chunk_id == 0:
recv_next = True
else:
recv_prev = True
if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
output_tensor_grad = None
bwd_wait_handles = None
return output_tensor_grad, bwd_wait_handles
else:
if config.timers is not None:
config.timers('backward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, bwd_wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op)
)
if config.timers is not None:
config.timers('backward-recv').stop()
if recv_prev:
return tensor_recv_prev, bwd_wait_handles
else:
return tensor_recv_next, bwd_wait_handles
def send_forward_recv_forward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig,
model_chunk_id,
async_op=False
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if not parallel_state.is_pipeline_last_stage():
tensor_send_next = output_tensor
if not parallel_state.is_pipeline_first_stage():
recv_prev = True
if model_chunk_id == 1:
if not parallel_state.is_pipeline_first_stage():
tensor_send_prev = output_tensor
if not parallel_state.is_pipeline_last_stage():
recv_next = True
if config.timers is not None:
config.timers('forward-send-forward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not async_op),
config=config
)
if config.timers is not None:
config.timers('forward-send-forward-recv').stop()
if model_chunk_id == 0:
if not parallel_state.is_pipeline_first_stage():
return tensor_recv_prev, fwd_wait_handles
else:
return None, fwd_wait_handles
else:
if not parallel_state.is_pipeline_last_stage():
return tensor_recv_next, fwd_wait_handles
else:
return None, fwd_wait_handles
def send_forward_recv_slave_forward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig,
model_chunk_id,
async_op=False,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage():
return None, None
tensor_send_next = output_tensor
recv_next = True
if model_chunk_id == 1:
if parallel_state.is_pipeline_first_stage():
return None, None
tensor_send_prev = output_tensor
recv_prev = True
if config.timers is not None:
config.timers('forward-send-slave-forward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not async_op),
config=config,
)
if config.timers is not None:
config.timers('forward-send-slave-forward-recv').stop()
if model_chunk_id == 0:
return tensor_recv_next, fwd_wait_handles
else:
return tensor_recv_prev, fwd_wait_handles
def generate_dualpipev_schedule(pp_size, num_microbatches):
num_microbatches = num_microbatches * 2
num_warmup_stages = [0] * pp_size
num_interleaved_forward_stages = [0] * pp_size
num_1b1w1f_stages = [0] * pp_size
num_overlap_stages = [0] * pp_size
num_1b1overlap_stages = [0] * pp_size
num_interleaved_backward_stages = [0] * pp_size
num_cooldown_stages = [0] * pp_size
pp_size *= 2
for i in range(pp_size // 2):
num_warmup_stages[i] = pp_size - 2 - i * 2
        num_interleaved_forward_stages[i] = i + 1  # each unit is one 1F1F pair
num_1b1w1f_stages[i] = pp_size // 2 - i - 1
num_overlap_stages[i] = num_microbatches - pp_size * 2 + i * 2 + 2
num_1b1overlap_stages[i] = (pp_size // 2 - i - 1) * 2
num_interleaved_backward_stages[i] = i + 1
num_cooldown_stages[i] = [i + 1, pp_size - 2 * i - 2, i + 1]
schedule_all_stages = {
'warmup': num_warmup_stages,
'interleaved_forward': num_interleaved_forward_stages,
'1b1w1f': num_1b1w1f_stages,
'overlap': num_overlap_stages,
'1b1overlap': num_1b1overlap_stages,
'interleaved_backward': num_interleaved_backward_stages,
'cooldown': num_cooldown_stages
}
return schedule_all_stages
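# Illustrative example (not executed): with pp_size=4 and num_microbatches=8,
# generate_dualpipev_schedule(4, 8) gives rank 0
#   warmup=6, interleaved_forward=1, 1b1w1f=3, overlap=2,
#   1b1overlap=6, interleaved_backward=1, cooldown=[1, 6, 1],
# i.e. later pipeline ranks get fewer warmup steps and more interleaved steps.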
def pretrain_gpt_forward_step_dualpipe(data_iterator, model: GPTModel, extra_block_kwargs=None):
    """Forward training step.
    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    from functools import partial
    from megatron.training import get_timers
    from pretrain_gpt import get_batch, loss_func
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
if extra_block_kwargs is not None:
        # execute forward-backward overlapping
output_tensor, model_graph, pp_comm_output = \
model(tokens, position_ids, attention_mask, labels=labels,
extra_block_kwargs=extra_block_kwargs)
return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
else:
output_tensor, model_graph = model(
tokens, position_ids, attention_mask, labels=labels)
return (output_tensor, model_graph), partial(loss_func, loss_mask)
def forward_step_no_model_graph(
forward_step_func,
model_chunk_id,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
):
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
num_tokens = torch.tensor(0, dtype=torch.int)
    if is_dualpipev_last_stgae(model_chunk_id):
if not collect_non_loss_data:
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
def backward_step_with_model_graph(input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph=None):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if config.timers is not None:
config.timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
if config.deallocate_pipeline_outputs:
if model_graph is None:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
layer_output_grad = gpt_model_backward(
output_tensor_grad[0], model_graph)
else:
torch.autograd.backward(
output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
if model_graph is not None:
input_tensor_grad.append(layer_output_grad)
else:
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
if config.timers is not None:
config.timers('backward-compute').stop()
return input_tensor_grad
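# Note: when a model_graph is passed in, the backward pass above is replayed through
# the captured layer graphs via gpt_model_backward and its returned gradient is used
# as the input gradient, instead of relying on torch.autograd and input_tensor.grad.
# This is the hook the forward/backward-overlap path (args.moe_fb_overlap) builds on.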
def forward_step_with_model_graph(
forward_step_func,
model_chunk_id,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
extra_block_kwargs=None,
):
"""Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.
Args:
forward_step_func (callable): The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
                for this object is that it needs to be acceptable as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally.
a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
            c. Any arbitrary data the user wants (e.g. a dictionary of tensors, a list
                of tensors, etc. in the case of inference). To trigger case (c) you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator): The data iterator.
model (nn.Module): The model to perform the forward step on.
num_microbatches (int): The number of microbatches.
input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step.
forward_data_store (list): The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object): The configuration object.
collect_non_loss_data (bool, optional): Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional): The current microbatch. Defaults to None.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
data_iterator, model, extra_block_kwargs)
        else:
            # pretrain_gpt_forward_step_dualpipe does not accept a
            # checkpoint_activations_microbatch argument, so it is not forwarded here.
            output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
                data_iterator, model, extra_block_kwargs
            )
num_tokens = torch.tensor(0, dtype=torch.int)
if is_dualpipev_last_stgae(model_chunk_id):
if not collect_non_loss_data:
next_info = None
if isinstance(output_tensor, tuple):
                # using pp overlapping
if len(output_tensor) == 2:
output_tensor, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor, model_graph, next_info = output_tensor
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
output_tensor = (output_tensor, model_graph, next_info) if next_info is not None else (
output_tensor, model_graph)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(LOSS_BACKWARD_SCALE)
if config.grad_scale_func is not None
else torch.tensor(1.0)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
shared_embedding = None
def get_shared_embedding_from_dual_chunk():
assert shared_embedding is not None
return shared_embedding
def set_shared_embedding_from_dual_chunk(model1, model2):
global shared_embedding
if shared_embedding is not None:
return
if model1.module.module.pre_process:
shared_embedding = model1.module.module.embedding.word_embeddings.weight
elif model2.module.module.pre_process:
shared_embedding = model2.module.module.embedding.word_embeddings.weight
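# In the DualPipeV layout assumed here, pipeline rank 0 hosts both the pre-process
# part of one chunk (the word embedding) and the post-process part of the other
# chunk (the output layer), so the shared weight can be picked up locally from
# whichever chunk has pre_process set.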
def forward_backward_pipelining_with_cutinhalf(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
args = get_args()
args.moe_fb_overlap = True
args.dualpipe_no_dw_detach = True
set_shared_embedding_from_dual_chunk(model[0], model[1])
    assert (
        isinstance(model, list) and len(model) == 2
    ), 'Dualpipe schedule only supports a model split into two consecutive chunks'
    assert (
        isinstance(data_iterator, list) and len(data_iterator) == 2
    ), 'Dualpipe schedule only supports two data iterators'
config = get_model_config(model[0])
config.batch_p2p_comm = False
# Needed only when gradients are finalized in M-Core
if config.finalize_model_grads_func is not None and not forward_only:
embedding_module = clear_embedding_activation_buffer(config, model)
if config.timers is not None:
config.timers('forward-backward',
log_level=1).start(barrier=config.barrier_with_L1_time)
# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Compute number of steps for each stage
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank = parallel_state.get_pipeline_model_parallel_rank()
schedule = generate_dualpipev_schedule(pp_size, num_microbatches)
model_type = get_model_type(model[0])
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
input_tensors = [[], []]
output_tensors = [[], []]
model_graphs = [[], []]
logits_inputs = []
forward_data_store = []
master_chunk_id = 0
slave_chunk_id = 1
master_cur_microbatch = 0
slave_cur_microbatch = num_microbatches
master_microbatch_max = num_microbatches
slave_microbatch_max = num_microbatches * 2
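    # Microbatch bookkeeping used below: forward passes of the master chunk (chunk 0)
    # are numbered [0, num_microbatches) and those of the slave chunk (chunk 1) are
    # numbered [num_microbatches, 2 * num_microbatches).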
set_dualpipe_chunk(master_chunk_id)
checkpoint_activations_microbatch = None
def forward_step_helper(model_chunk_id, current_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=False, extra_block_kwargs=None):
input_tensor = input_tensors[model_chunk_id][-1][1]
output_tensor, num_tokens = forward_step_with_model_graph(
forward_step_func,
model_chunk_id,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch,
current_microbatch=current_microbatch,
extra_block_kwargs=extra_block_kwargs
)
if isinstance(output_tensor, tuple):
if len(output_tensor) == 2:
output_tensor_, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor_, model_graph, pp_comm_output = output_tensor
if is_dualpipev_last_stgae(model_chunk_id):
logits_inputs.append(
model_graph.layer_graphs[-1].unperm2_graph[1])
model_graphs[model_chunk_id].append(model_graph)
else:
output_tensor_ = output_tensor
output_tensors[model_chunk_id].append(output_tensor_)
if extra_block_kwargs is not None:
input_tensors[1 - model_chunk_id].pop(0)
output_tensors[1 - model_chunk_id].pop(0)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def check_pipeline_stage(model_chunk_id, fwd_send_only):
send_next, recv_next, send_prev, recv_prev = True, True, True, True
if parallel_state.is_pipeline_first_stage():
send_prev, recv_prev = False, False
if parallel_state.is_pipeline_last_stage():
send_next, recv_next = False, False
if model_chunk_id == 0:
return P2PCommParams(send_next=send_next, recv_next=not fwd_send_only and recv_next), P2PCommParams(send_next=send_next, recv_next=recv_next)
else:
return P2PCommParams(send_prev=send_prev, recv_prev=not fwd_send_only and recv_prev), P2PCommParams(send_prev=send_prev, recv_prev=recv_prev)
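    # check_pipeline_stage builds the P2P parameters handed to the overlapped
    # forward/backward step: the first P2PCommParams drives the forward-direction
    # send/recv and the second the backward-direction send/recv, with the pipeline
    # boundary stages dropping the sends/recvs that would cross the pipeline ends.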
input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]
fwd_wait_handles_warmup = None
# Run warmup forward passes
for i in range(schedule['warmup'][rank]):
if args.moe_fb_overlap:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensor_warmup, _ = forward_step_helper(master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=(i == 0))
else:
output_tensor_warmup, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=(i == 0),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
master_cur_microbatch += 1
if i != schedule['warmup'][rank] - 1:
input_tensor, _ = send_forward_recv_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id)
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
else:
input_tensor, _ = recv_forward(
tensor_shape, config, master_chunk_id)
fwd_wait_handles_warmup = send_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id, async_op=True)
    # Run interleaved forward passes for the two model chunks
fwd_wait_handles = None
fwd_wait_handles_slave_chunk = None
fwd_wait_handles_send = None
for i in range(schedule['interleaved_forward'][rank]):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
is_first_microbatch = parallel_state.is_pipeline_last_stage() and (i == 0)
set_dualpipe_chunk(master_chunk_id)
if args.moe_fb_overlap:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensor, _ = forward_step_helper(master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch)
else:
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch,
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
master_cur_microbatch += 1
if not parallel_state.is_pipeline_last_stage() and fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if parallel_state.is_pipeline_last_stage():
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
else:
input_tensor_slave_chunk, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
if fwd_wait_handles_warmup is not None:
for req in fwd_wait_handles_warmup:
req.wait()
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
fwd_wait_handles_warmup = None
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
set_dualpipe_chunk(slave_chunk_id)
if args.moe_fb_overlap:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensor_slave_chunk, _ = forward_step_helper(
slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
else:
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch,
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if i == schedule['interleaved_forward'][rank] - 1:
firstFB_no_overlp = False
firstFB_no_overlp_handle = None
            # the last pipeline rank does not overlap the first F&B
if parallel_state.is_pipeline_last_stage():
firstFB_no_overlp = True
output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
tensor_shape, config, slave_chunk_id, async_op=True)
else:
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
tensor_shape, config, slave_chunk_id, async_op=True)
if not parallel_state.is_pipeline_last_stage():
output_tensor_send = output_tensor
fwd_wait_handles_send = send_forward(
output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
# Run 1b1w1f stages for slave chunk
bwd_wait_handles = None
for _ in range(schedule['1b1w1f'][rank]):
WeightGradStore.start_decouple()
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(slave_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[slave_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
model_graph = model_graphs[slave_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
WeightGradStore.end_decouple()
        # Keep this send synchronous; an asynchronous send here would increase memory usage.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
        # Keep this recv synchronous; an asynchronous call here would increase memory usage.
input_tensor_slave_chunk, recv_forward_handle = recv_forward(
tensor_shape, config, slave_chunk_id)
# 1w: Weight Grad Compute
WeightGradStore.pop()
if recv_forward_handle is not None:
for req in recv_forward_handle:
req.wait()
recv_forward_handle = None
# 1F: Forward pass
set_dualpipe_chunk(slave_chunk_id)
if args.moe_fb_overlap:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensor_slave_chunk, _ = forward_step_helper(
slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
else:
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
tensor_shape, config, slave_chunk_id, async_op=True)
fwd_wait_handles_recv = None
    # Run overlapping forward & backward stages
fwd_model_chunk_id = master_chunk_id
bwd_model_chunk_id = slave_chunk_id
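    # Each iteration below runs one forward microbatch on fwd_model_chunk_id and one
    # backward microbatch on bwd_model_chunk_id (overlapped inside the model when
    # args.moe_fb_overlap is set); the two chunk ids are swapped at the end of every
    # iteration.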
for _ in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
only_bwd = False
if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
only_bwd = True
if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
only_bwd = True
if args.moe_fb_overlap and not firstFB_no_overlp:
if not only_bwd:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
if fwd_wait_handles_recv is not None:
for req in fwd_wait_handles_recv:
req.wait()
fwd_wait_handles_recv = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if not parallel_state.is_pipeline_last_stage() or fwd_model_chunk_id == master_chunk_id:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
fwd_send_only = False
if fwd_model_chunk_id == slave_chunk_id and master_cur_microbatch == master_microbatch_max:
fwd_send_only = True
extra_block_kwargs = {}
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
extra_block_kwargs.setdefault(
'bwd_model_grad', input_tensor_grad)
else:
extra_block_kwargs.setdefault(
'bwd_model_grad', output_tensor_grad_bwd)
fwd_pp_comm_params, bwd_pp_comm_params = check_pipeline_stage(
fwd_model_chunk_id, fwd_send_only)
fwd_pp_comm_params.config, bwd_pp_comm_params.config = config, config
fwd_pp_comm_params.tensor_shape, bwd_pp_comm_params.tensor_shape = tensor_shape, tensor_shape
extra_block_kwargs.setdefault(
'bwd_model_graph', model_graphs[bwd_model_chunk_id].pop(0))
extra_block_kwargs.setdefault(
'pp_comm_params', fwd_pp_comm_params)
extra_block_kwargs.setdefault(
'bwd_pp_comm_params', bwd_pp_comm_params)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensor, model_graph, pp_comm_output = forward_step_helper(fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch,
extra_block_kwargs=extra_block_kwargs)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
output_tensor_grad_bwd = pp_comm_output.input_tensor_grad
else:
input_tensor, fwd_wait_handles = pp_comm_output.input_tensor, pp_comm_output.fwd_wait_handles
output_tensor_grad_bwd, bwd_wait_handles = pp_comm_output.output_tensor_grad, pp_comm_output.bwd_wait_handles
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
else:
slave_cur_microbatch += 1
if fwd_wait_handles_slave_chunk is not None:
                    for req in fwd_wait_handles_slave_chunk:  # wait for the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
else:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, fwd_wait_handles_recv = recv_forward(
tensor_shape, config, slave_chunk_id, async_op=True)
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
                    # acts as send_backward_recv_slave_backward (reuses send_forward_recv_slave_forward)
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
else:
firstFB_no_overlp = False
if not only_bwd:
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
if args.moe_fb_overlap:
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensor, _ = forward_step_helper(
fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch)
else:
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
fwd_model_chunk_id,
data_iterator[fwd_model_chunk_id],
model[fwd_model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
if firstFB_no_overlp_handle is not None:
for req in firstFB_no_overlp_handle:
req.wait()
firstFB_no_overlp_handle = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
                    # acts as send_backward_recv_slave_backward (reuses send_forward_recv_slave_forward)
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
if fwd_wait_handles_slave_chunk is not None:
                    for req in fwd_wait_handles_slave_chunk:  # wait for the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
# only run backward
else:
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
                    # acts as send_backward_recv_slave_backward (reuses send_forward_recv_slave_forward)
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
# swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
# Run cooldown phases
merged_input_tensors = []
merged_output_tensors = []
while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
if len(input_tensors[bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[1 - bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
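    # Cooldown: drain the remaining saved activations of both chunks in the
    # interleaved order built above, running pp_size more backward steps on this rank.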
bwd_wait_handles_recv = None
for i in range(pp_size):
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if bwd_wait_handles_recv is not None:
for req in bwd_wait_handles_recv:
req.wait()
bwd_wait_handles_recv = None
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
if not args.dualpipe_no_dw_detach:
WeightGradStore.start_decouple()
if args.moe_fb_overlap:
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if not args.dualpipe_no_dw_detach:
WeightGradStore.end_decouple()
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
elif i >= schedule['cooldown'][rank][0] - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
tensor_shape, config, bwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
                # acts as send_backward_recv_slave_backward (reuses send_forward_recv_slave_forward)
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
WeightGradStore.flush_chunk_grad()
if i >= schedule['cooldown'][rank][0] - 1:
WeightGradStore.pop_single()
for _ in range(schedule['cooldown'][rank][2] - 1):
WeightGradStore.pop_single()
assert WeightGradStore.weight_grad_queue.empty()
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if config.finalize_model_grads_func is not None and not forward_only:
# If defer_embedding_wgrad_compute is enabled we need to do the
# weight gradient GEMM's here.
finish_embedding_wgrad_compute(config, embedding_module)
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config.finalize_model_grads_func(
model, total_num_tokens if config.calculate_per_token_loss else None
)
return forward_data_store