Commit 8d5bae2a authored by dongcl

add dualpipev_chunks to support dualpipev

parent e5f5eb4d
import torch
from functools import wraps
from typing import List, Optional
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.training.global_vars import get_args, get_timers
from megatron.training.utils import unwrap_model, logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk
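# In the dualpipev (cut-in-half) schedule, pipeline rank 0 hosts both model chunk 0
# (which owns the embedding, i.e. pre_process) and model chunk 1 (which owns the
# output layer, i.e. post_process). The Float16Module forward below therefore converts
# inputs to fp16 only for chunk 0 on that rank, and converts outputs back to fp32
# only for chunk 1 on that rank, instead of keying on the first/last pipeline stage.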
def dualpipev_fp16forward(self, *inputs, **kwargs):
dualpipe_first_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 0
if dualpipe_first_stage:
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
dualpipe_last_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
if dualpipe_last_stage:
outputs = float16_to_fp32(outputs)
return outputs
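# get_model builds two model chunks per pipeline rank for the dualpipev schedule.
# args.dualpipev_first_chunk tells model_provider_func which half it is building:
# chunk 0 carries pre_process (the embedding) on the first pipeline stage, and
# chunk 1 carries post_process (the output layer and loss) on that same stage.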
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
assert model_type != ModelType.encoder_and_decoder, \
"dualpipev schedule not supported for model with both encoder and decoder"
model = []
args.dualpipev_first_chunk = True
first_model = model_provider_func(
pre_process=mpu.is_pipeline_first_stage(),
post_process=False
)
first_model.model_type = model_type
model.append(first_model)
args.dualpipev_first_chunk = False
second_model = model_provider_func(
pre_process=False,
post_process=mpu.is_pipeline_first_stage()
)
second_model.model_type = model_type
model.append(second_model)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(
param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if wrap_with_ddp:
config = get_model_config(model[0])
ddp_config = DistributedDataParallelConfig(
grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32,
overlap_grad_reduce=args.overlap_grad_reduce,
use_distributed_optimizer=args.use_distributed_optimizer,
check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad,
bucket_size=args.ddp_bucket_size,
average_in_collective=args.ddp_average_in_collective)
model = [DDP(config,
ddp_config,
model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0))
for (model_chunk_idx, model_chunk) in enumerate(model)]
# Broadcast params from data parallel src rank to other data parallel ranks.
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
return model
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
rerun_state_machine = get_rerun_state_machine()
while rerun_state_machine.should_run_forward_backward(data_iterator):
# Set grad to zero.
for model_chunk in model:
model_chunk.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False)
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Vision gradients.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
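# In dualpipev the loss is computed by model chunk 1, which lives on pipeline
# rank 0, so microbatch losses are averaged on the first pipeline stage rather
# than on the last stage as in the standard schedule.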
if mpu.is_pipeline_first_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0].keys():
numerator = 0
denominator = 0
for x in losses_reduced:
val = x[key]
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if isinstance(val, tuple) or isinstance(val, list):
numerator += val[0]
denominator += val[1]
else:
# legacy behavior. we average over the number of microbatches,
# and so the denominator is 1.
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
def get_num_layers_to_build(config: TransformerConfig) -> int:
num_layers_per_pipeline_rank = (
config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
num_layers_to_build = num_layers_per_pipeline_rank // 2
return num_layers_to_build
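# Each rank builds half of its usual share of layers per model chunk. For example,
# with config.num_layers == 32 and a pipeline-parallel size of 4, each rank owns
# 32 // 4 == 8 layers, split into 4 layers for each of its two dualpipev chunks.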
def _allreduce_embedding_grads_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if get_args().schedules_method == 'dualpipev':
# dualpipev does not need the embedding all-reduce:
# the embedding and the lm head live on the same rank.
if not get_args().untie_embeddings_and_output_weights:
raise NotImplementedError
else:
return
else:
return fn(*args, **kwargs)
return wrapper
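# With dualpipev the embedding and the output layer sit on the same pipeline rank
# (chunks 0 and 1 of rank 0), so no cross-stage embedding-gradient all-reduce is
# needed; tied input/output embeddings raise NotImplementedError above because
# weight sharing across the two chunks is not handled by this patch.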
@@ -21,8 +21,6 @@ from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel.p2p_communication import _communicate
from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
from megatron.core.models.gpt import GPTModel
from mindspeed.core.pipeline_parallel.fb_overlap.gpt_model import gpt_model_backward
from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import P2PCommParams
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
@@ -34,10 +32,10 @@ LOSS_BACKWARD_SCALE = torch.tensor(1.0)
_DUALPIPE_CHUNK = None
def set_dualpipe_chunk(chunkid):
def set_dualpipe_chunk(chunk_id):
"""set_dualpipe_chunk for fp16forward patch"""
global _DUALPIPE_CHUNK
_DUALPIPE_CHUNK = chunkid
_DUALPIPE_CHUNK = chunk_id
def get_dualpipe_chunk():
@@ -48,7 +46,7 @@ def get_dualpipe_chunk():
raise AssertionError("_DUALPIPE_CHUNK is None")
def is_dualpipev_last_stgae(model_chunk_id):
def is_dualpipev_last_stage(model_chunk_id):
return parallel_state.is_pipeline_first_stage(ignore_virtual=True) and model_chunk_id == 1
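# In the folded dualpipev pipeline, the logical last stage (the one that computes
# the loss) is pipeline rank 0 running model chunk 1, hence the first-stage check.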
@@ -59,11 +57,11 @@ def send_forward(output_tensor: torch.Tensor, tensor_shape, config: ModelParalle
"""
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
return None
tensor_send_next = output_tensor
else:
if parallel_state.is_pipeline_first_stage():
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
return None
tensor_send_prev = output_tensor
@@ -93,11 +91,11 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelPa
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_first_stage():
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
return None
tensor_send_prev = input_tensor_grad
else:
if parallel_state.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
return None
tensor_send_next = input_tensor_grad
@@ -128,7 +126,10 @@ def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_i
else:
recv_next = True
if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
if (
(parallel_state.is_pipeline_first_stage(ignore_virtual=True) and recv_prev)
or (parallel_state.is_pipeline_last_stage(ignore_virtual=True) and recv_next)
):
fwd_wait_handles = None
return None, fwd_wait_handles
else:
@@ -163,7 +164,10 @@ def recv_backward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_
else:
recv_prev = True
if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
if (
(parallel_state.is_pipeline_first_stage(ignore_virtual=True) and recv_prev)
or (parallel_state.is_pipeline_last_stage(ignore_virtual=True) and recv_next)
):
output_tensor_grad = None
bwd_wait_handles = None
return output_tensor_grad, bwd_wait_handles
@@ -203,14 +207,14 @@ def send_forward_recv_forward(
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if not parallel_state.is_pipeline_last_stage():
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
tensor_send_next = output_tensor
if not parallel_state.is_pipeline_first_stage():
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
recv_prev = True
if model_chunk_id == 1:
if not parallel_state.is_pipeline_first_stage():
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
tensor_send_prev = output_tensor
if not parallel_state.is_pipeline_last_stage():
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = True
if config.timers is not None:
@@ -228,22 +232,23 @@ def send_forward_recv_forward(
config.timers('forward-send-forward-recv').stop()
if model_chunk_id == 0:
if not parallel_state.is_pipeline_first_stage():
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
return tensor_recv_prev, fwd_wait_handles
else:
return None, fwd_wait_handles
else:
if not parallel_state.is_pipeline_last_stage():
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
return tensor_recv_next, fwd_wait_handles
else:
return None, fwd_wait_handles
# TODO (dongcl)
def send_forward_recv_slave_forward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig,
model_chunk_id,
fwd_model_chunk_id,
async_op=False,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
@@ -251,13 +256,13 @@ def send_forward_recv_slave_forward(
"""
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage():
if fwd_model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
return None, None
tensor_send_next = output_tensor
recv_next = True
if model_chunk_id == 1:
if parallel_state.is_pipeline_first_stage():
if fwd_model_chunk_id == 1:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
return None, None
tensor_send_prev = output_tensor
recv_prev = True
@@ -275,7 +280,49 @@ def send_forward_recv_slave_forward(
if config.timers is not None:
config.timers('forward-send-slave-forward-recv').stop()
if model_chunk_id == 0:
if fwd_model_chunk_id == 0:
return tensor_recv_next, fwd_wait_handles
else:
return tensor_recv_prev, fwd_wait_handles
def send_backward_recv_slave_backward(
input_tensor_grad: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig,
fwd_model_chunk_id,
async_op=False,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if fwd_model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
return None, None
tensor_send_next = input_tensor_grad
recv_next = True
if fwd_model_chunk_id == 1:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
return None, None
tensor_send_prev = input_tensor_grad
recv_prev = True
if config.timers is not None:
config.timers('backward-send-slave-backward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not async_op),
config=config,
)
if config.timers is not None:
config.timers('backward-send-slave-backward-recv').stop()
if fwd_model_chunk_id == 0:
return tensor_recv_next, fwd_wait_handles
else:
return tensor_recv_prev, fwd_wait_handles
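# send_forward_recv_slave_forward / send_backward_recv_slave_backward exchange
# tensors with the same neighbour in one batched _communicate call: the current
# chunk's tensor is sent in its direction of travel while the opposite ("slave")
# chunk's tensor, flowing the other way through the V, is received, so the two
# transfers can overlap.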
@@ -320,38 +367,6 @@ def generate_dualpipev_schedule(pp_size, num_microbatches):
return schedule_all_stages
def pretrain_gpt_forward_step_dualpipe(data_iterator, model: GPTModel, extra_block_kwargs=None):
from megatron.training import get_timers
from functools import partial
from pretrain_gpt import get_batch, loss_func
"""Forward training step.
Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
"""
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
if extra_block_kwargs is not None:
# excute forward backward overlaping
output_tensor, model_graph, pp_comm_output = \
model(tokens, position_ids, attention_mask, labels=labels,
extra_block_kwargs=extra_block_kwargs)
return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
else:
output_tensor, model_graph = model(
tokens, position_ids, attention_mask, labels=labels)
return (output_tensor, model_graph), partial(loss_func, loss_mask)
def forward_step_no_model_graph(
forward_step_func,
model_chunk_id,
@@ -395,18 +410,20 @@ def forward_step_no_model_graph(
)
num_tokens = torch.tensor(0, dtype=torch.int)
if is_dualpipev_last_stgae:
if is_dualpipev_last_stage(model_chunk_id):
if not collect_non_loss_data:
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor *= parallel_state.get_context_parallel_world_size()
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor *= parallel_state.get_context_parallel_world_size()
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
else:
@@ -417,251 +434,36 @@ def forward_step_no_model_graph(
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale
# explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
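# Dividing by num_microbatches keeps the auxiliary-loss gradient consistent with
# the main loss, which is averaged over the microbatches during gradient accumulation.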
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
def backward_step_with_model_graph(input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph=None):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if config.timers is not None:
config.timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
if config.deallocate_pipeline_outputs:
if model_graph is None:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
layer_output_grad = gpt_model_backward(
output_tensor_grad[0], model_graph)
else:
torch.autograd.backward(
output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
if model_graph is not None:
input_tensor_grad.append(layer_output_grad)
if config.calculate_per_token_loss:
MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
else:
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
if config.timers is not None:
config.timers('backward-compute').stop()
return input_tensor_grad
def forward_step_with_model_graph(
forward_step_func,
model_chunk_id,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
extra_block_kwargs=None,
):
"""Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.
Args:
forward_step_func (callable): The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
for this object is that it needs to be acceptible as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally.
a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
of tensors, etc in the case of inference). To trigger case 3 you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator): The data iterator.
model (nn.Module): The model to perform the forward step on.
num_microbatches (int): The number of microbatches.
input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step.
forward_data_store (list): The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object): The configuration object.
collect_non_loss_data (bool, optional): Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional): The current microbatch. Defaults to None.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
data_iterator, model, extra_block_kwargs)
else:
output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
data_iterator, model, checkpoint_activations_microbatch, extra_block_kwargs
)
num_tokens = torch.tensor(0, dtype=torch.int)
if is_dualpipev_last_stgae(model_chunk_id):
if not collect_non_loss_data:
next_info = None
if isinstance(output_tensor, tuple):
# use pp overlaping,
if len(output_tensor) == 2:
output_tensor, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor, model_graph, next_info = output_tensor
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
output_tensor = (output_tensor, model_graph, next_info) if next_info is not None else (
output_tensor, model_graph)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(LOSS_BACKWARD_SCALE)
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# If T5 model and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
@@ -713,11 +515,11 @@ def forward_backward_pipelining_with_cutinhalf(
set_shared_embedding_from_dual_chunk(model[0], model[1])
assert (
isinstance(model, list) and len(model) == 2
), 'Dualpipe Schedule only support chunk model for two consecutive chunks'
), 'Dualpipe Schedule expects two model chunks'
assert (
isinstance(data_iterator, list) and len(data_iterator) == 2
), 'Dualpipe Schedule only support two data_iterators'
), 'Dualpipe Schedule expects two data_iterators'
config = get_model_config(model[0])
config.batch_p2p_comm = False
@@ -727,8 +529,7 @@ def forward_backward_pipelining_with_cutinhalf(
embedding_module = clear_embedding_activation_buffer(config, model)
if config.timers is not None:
config.timers('forward-backward',
log_level=1).start(barrier=config.barrier_with_L1_time)
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
# Disable async grad reductions
no_sync_func = config.no_sync_func
@@ -783,97 +584,30 @@ def forward_backward_pipelining_with_cutinhalf(
checkpoint_activations_microbatch = None
def forward_step_helper(model_chunk_id, current_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=False, extra_block_kwargs=None):
input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]
input_tensor = input_tensors[model_chunk_id][-1][1]
output_tensor, num_tokens = forward_step_with_model_graph(
fwd_wait_handles_warmup = None
# Run warmup forward passes
for i in range(schedule['warmup'][rank]):
output_tensor_warmup, num_tokens = forward_step_no_model_graph(
forward_step_func,
model_chunk_id,
data_iterator[model_chunk_id],
model[model_chunk_id],
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch,
current_microbatch=current_microbatch,
extra_block_kwargs=extra_block_kwargs
is_first_microbatch=(i == 0),
current_microbatch=master_cur_microbatch
)
if isinstance(output_tensor, tuple):
if len(output_tensor) == 2:
output_tensor_, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor_, model_graph, pp_comm_output = output_tensor
if is_dualpipev_last_stgae(model_chunk_id):
logits_inputs.append(
model_graph.layer_graphs[-1].unperm2_graph[1])
model_graphs[model_chunk_id].append(model_graph)
else:
output_tensor_ = output_tensor
output_tensors[model_chunk_id].append(output_tensor_)
if extra_block_kwargs is not None:
input_tensors[1 - model_chunk_id].pop(0)
output_tensors[1 - model_chunk_id].pop(0)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def check_pipeline_stage(model_chunk_id, fwd_send_only):
send_next, recv_next, send_prev, recv_prev = True, True, True, True
if parallel_state.is_pipeline_first_stage():
send_prev, recv_prev = False, False
if parallel_state.is_pipeline_last_stage():
send_next, recv_next = False, False
if model_chunk_id == 0:
return P2PCommParams(send_next=send_next, recv_next=not fwd_send_only and recv_next), P2PCommParams(send_next=send_next, recv_next=recv_next)
else:
return P2PCommParams(send_prev=send_prev, recv_prev=not fwd_send_only and recv_prev), P2PCommParams(send_prev=send_prev, recv_prev=recv_prev)
input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]
fwd_wait_handles_warmup = None
# Run warmup forward passes
for i in range(schedule['warmup'][rank]):
if args.moe_fb_overlap:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensor_warmup, _ = forward_step_helper(master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=(i == 0))
else:
output_tensor_warmup, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=(i == 0),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
master_cur_microbatch += 1
@@ -899,45 +633,39 @@ def forward_backward_pipelining_with_cutinhalf(
req.wait()
fwd_wait_handles = None
is_first_microbatch = parallel_state.is_pipeline_last_stage() and (i == 0)
is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
set_dualpipe_chunk(master_chunk_id)
if args.moe_fb_overlap:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensor, _ = forward_step_helper(master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch)
else:
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch,
current_microbatch=master_cur_microbatch
)
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch,
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
master_cur_microbatch += 1
if not parallel_state.is_pipeline_last_stage() and fwd_wait_handles_send is not None:
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if parallel_state.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward(
@@ -964,31 +692,24 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles_slave_chunk = None
set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch,
)
if args.moe_fb_overlap:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensor_slave_chunk, _ = forward_step_helper(
slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
else:
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch,
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
@@ -997,7 +718,7 @@ def forward_backward_pipelining_with_cutinhalf(
firstFB_no_overlp = False
firstFB_no_overlp_handle = None
# the last rank does not overlap the first F&B
if parallel_state.is_pipeline_last_stage():
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
firstFB_no_overlp = True
output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
tensor_shape, config, slave_chunk_id, async_op=True)
@@ -1008,7 +729,7 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
tensor_shape, config, slave_chunk_id, async_op=True)
if not parallel_state.is_pipeline_last_stage():
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
output_tensor_send = output_tensor
fwd_wait_handles_send = send_forward(
output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
@@ -1024,32 +745,12 @@ def forward_backward_pipelining_with_cutinhalf(
WeightGradStore.start_decouple()
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(slave_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[slave_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
model_graph = model_graphs[slave_chunk_id].pop(0)
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
WeightGradStore.end_decouple()
@@ -1084,31 +785,24 @@ def forward_backward_pipelining_with_cutinhalf(
# 1F: Forward pass
set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch
)
if args.moe_fb_overlap:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensor_slave_chunk, _ = forward_step_helper(
slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
else:
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
@@ -1129,277 +823,110 @@ def forward_backward_pipelining_with_cutinhalf(
if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
only_bwd = True
if args.moe_fb_overlap and not firstFB_no_overlp:
if not only_bwd:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
if fwd_wait_handles_recv is not None:
for req in fwd_wait_handles_recv:
req.wait()
fwd_wait_handles_recv = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if not parallel_state.is_pipeline_last_stage() or fwd_model_chunk_id == master_chunk_id:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
fwd_send_only = False
if fwd_model_chunk_id == slave_chunk_id and master_cur_microbatch == master_microbatch_max:
fwd_send_only = True
extra_block_kwargs = {}
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
extra_block_kwargs.setdefault(
'bwd_model_grad', input_tensor_grad)
else:
extra_block_kwargs.setdefault(
'bwd_model_grad', output_tensor_grad_bwd)
fwd_pp_comm_params, bwd_pp_comm_params = check_pipeline_stage(
fwd_model_chunk_id, fwd_send_only)
fwd_pp_comm_params.config, bwd_pp_comm_params.config = config, config
fwd_pp_comm_params.tensor_shape, bwd_pp_comm_params.tensor_shape = tensor_shape, tensor_shape
if not only_bwd:
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
extra_block_kwargs.setdefault(
'bwd_model_graph', model_graphs[bwd_model_chunk_id].pop(0))
extra_block_kwargs.setdefault(
'pp_comm_params', fwd_pp_comm_params)
extra_block_kwargs.setdefault(
'bwd_pp_comm_params', bwd_pp_comm_params)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
fwd_model_chunk_id,
data_iterator[fwd_model_chunk_id],
model[fwd_model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor)
output_tensor, model_graph, pp_comm_output = forward_step_helper(fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch,
extra_block_kwargs=extra_block_kwargs)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
output_tensor_grad_bwd = pp_comm_output.input_tensor_grad
else:
input_tensor, fwd_wait_handles = pp_comm_output.input_tensor, pp_comm_output.fwd_wait_handles
output_tensor_grad_bwd, bwd_wait_handles = pp_comm_output.output_tensor_grad, pp_comm_output.bwd_wait_handles
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
else:
slave_cur_microbatch += 1
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:  # sync the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
else:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if firstFB_no_overlp_handle is not None:
for req in firstFB_no_overlp_handle:
req.wait()
firstFB_no_overlp_handle = None
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, fwd_wait_handles_recv = recv_forward(
tensor_shape, config, slave_chunk_id, async_op=True)
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:  # sync the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
# only run backward
else:
firstFB_no_overlp = False
if not only_bwd:
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
if args.moe_fb_overlap:
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensor, _ = forward_step_helper(
fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch)
else:
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
fwd_model_chunk_id,
data_iterator[fwd_model_chunk_id],
model[fwd_model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
if firstFB_no_overlp_handle is not None:
for req in firstFB_no_overlp_handle:
req.wait()
firstFB_no_overlp_handle = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:  # sync the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
# only run backward
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
# swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
@@ -1438,16 +965,9 @@ def forward_backward_pipelining_with_cutinhalf(
if not args.dualpipe_no_dw_detach:
WeightGradStore.start_decouple()
if args.moe_fb_overlap:
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if not args.dualpipe_no_dw_detach:
WeightGradStore.end_decouple()
@@ -1465,7 +985,7 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
WeightGradStore.flush_chunk_grad()
......