Commit 770fa304 authored by dongcl

Modify mtp

parent 8096abd4
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from einops import rearrange
import torch
from megatron.training import get_args
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
from mindspeed.ops.gmm import GMMFunction
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
def get_gmm_weight_grad(inputs, grad_out, group_list, group_list_data_type, weight_param, weight_tensor):
if WeightGradStore.is_decoupleBlock:
WeightGradStore.put(
[inputs, group_list, group_list_data_type],
grad_out,
weight_param,
sequence_parallel=False,
in_row=False,
)
if hasattr(weight_param, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
# When overlap_grad_reduce is True, need to ensure that backward hooks
# are all run on the main backprop thread to prevent deadlocks. Setup
# dummy grad_weight tensor to prevent backward hooks from being run
# in a background thread.
shape = list(weight_tensor.shape)
shape[1], shape[2] = shape[2], shape[1]
weight_param.skip_grad_accum = True
grad_weights = None
else:
if get_args().gemm_gradient_accumulation_fusion:
npu_groupmatmul_add_fp32(inputs, grad_out, group_list, weight_param.main_grad)
if hasattr(weight_param, 'grad_added_to_main_grad'):
shape = list(weight_tensor.shape)
shape[1], shape[2] = shape[2], shape[1]
if getattr(weight_tensor, 'zero_out_wgrad', False):
grad_weights = torch.zeros(
shape,
dtype=inputs.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
grad_weights = torch.empty(
shape,
dtype=inputs.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
weight_param.grad_added_to_main_grad = True
else:
grad_weights = None
else:
grad_weights = GMMFunction.builder.load().npu_gmm([inputs.t()], [grad_out], [], group_list, 2,
group_list_data_type)[0]
return grad_weights
class GroupedMatmulWithWeightGradDetach(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, weight_tensor, weight_param, group_list, in_row=False):
mm_out = GMMFunction.builder.load().npu_gmm([inputs], [weight_tensor], [], group_list, 0, 0)[0]
ctx.save_for_backward(inputs, weight_tensor, group_list)
ctx.weight_param = weight_param
ctx.in_row = in_row
return mm_out
@staticmethod
def backward(ctx, *grad_outs):
grad_out = grad_outs[0]
inputs, weight_tensor, group_list = ctx.saved_tensors
weight_param = ctx.weight_param
weight_tensor = rearrange(weight_tensor, 'n h f -> n f h')
grad_inputs = \
GMMFunction.builder.load().npu_gmm([grad_out], [weight_tensor], [], group_list, 0, 0)[0]
grad_weights = get_gmm_weight_grad(inputs, grad_out, group_list, 0, weight_param,
weight_tensor)
return grad_inputs, grad_weights, None, None, None
def npu_gmm_with_detach(inputs, weight_tensor, weight_param, bias=None, group_list=None):
return GroupedMatmulWithWeightGradDetach.apply(inputs, weight_tensor, weight_param, group_list)
def group_mlp_forward_detach(self, permuted_local_hidden_states, tokens_per_expert):
args = get_args()
is_recompute_activation = args.moe_zero_memory == 'level0' or should_recompute_activation(self.layer_number)
if permuted_local_hidden_states.nelement() != 0:
group_list = torch.cumsum(tokens_per_expert, dim=0)
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = npu_gmm_with_detach(permuted_local_hidden_states, w1, self.weight1, bias=None, group_list=group_list)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = npu_gmm_with_detach(intermediate_parallel, w2, self.weight2, bias=None, group_list=group_list)
if is_recompute_activation:
intermediate_parallel.untyped_storage().resize_(0)
else:
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure parameters still have gradients when no tokens are routed to this set of experts.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
fc1_output = torch.matmul(permuted_local_hidden_states, w1)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = torch.matmul(intermediate_parallel, w2)
if is_recompute_activation:
intermediate_parallel.untyped_storage().resize_(0)
return (fc2_output, fc1_output, intermediate_parallel), None
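# Illustrative sketch (hypothetical helper, for clarity only): group_mlp_forward_detach
# above builds `group_list` as the cumulative sum of tokens_per_expert; each entry is the
# exclusive end row of one expert's slice in the flattened token tensor handed to the
# grouped matmul. The function below is a standalone illustration of that convention.
def _group_list_example():
    # Four local experts receiving 3, 1, 0 and 2 tokens respectively.
    tokens_per_expert = torch.tensor([3, 1, 0, 2])
    group_list = torch.cumsum(tokens_per_expert, dim=0)  # tensor([3, 4, 4, 6])
    # Rows 0:3 -> expert 0, 3:4 -> expert 1, 4:4 -> expert 2 (empty), 4:6 -> expert 3.
    return group_list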
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import os
import warnings
from typing import Any, Callable, List, Optional
import torch
import torch.distributed
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.parameter import Parameter
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
linear_with_grad_accumulation_and_async_allreduce,
linear_with_frozen_weight
)
from megatron.core.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim
)
from megatron.core.tensor_parallel.utils import VocabUtility, divide, split_tensor_along_last_dim
from megatron.core.utils import (
make_tp_sharded_tensor_for_checkpoint,
prepare_input_tensors_for_wgrad_compute
)
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.training import get_args
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
def linear_backward_wgrad_detach(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_output_buffer = ctx.grad_output_buffer
wgrad_deferral_limit = ctx.wgrad_deferral_limit
wgrad_compute = True
if grad_output_buffer is not None:
if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
grad_output_buffer.append(grad_output)
wgrad_compute = False
if wgrad_compute:
if ctx.sequence_parallel and not WeightGradStore.is_decoupleBlock:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(
dim_size, input_.dtype, "mpu"
)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input_, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input_
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel and wgrad_compute and not WeightGradStore.is_decoupleBlock:
handle.wait()
if wgrad_compute and not WeightGradStore.is_decoupleBlock:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, total_input
)
if ctx.allreduce_dgrad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if ctx.sequence_parallel:
assert not ctx.allreduce_dgrad
dim_size = list(input_.size())
sub_grad_input = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device(), requires_grad=False
)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(
sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
if WeightGradStore.is_decoupleBlock:
# TODO: remove clone under MLA setting
WeightGradStore.put(
total_input.clone().detach(),
grad_output.clone().detach(),
weight,
ctx.sequence_parallel,
in_row=not ctx.sequence_parallel
)
if hasattr(weight, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
weight.skip_grad_accum = True
grad_weight = None
else:
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
if weight.main_grad.dtype == torch.float32:
from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
npu_matmul_add_fp32(total_input, grad_output, weight.main_grad)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
if hasattr(weight, 'grad_added_to_main_grad'):
# When overlap_grad_reduce is True, need to ensure that backward hooks
# are all run on the main backprop thread to prevent deadlocks. Setup
# dummy grad_weight tensor to prevent backward hooks from being run
# in a background thread.
if getattr(weight, 'zero_out_wgrad', False):
grad_weight = torch.zeros(
weight.main_grad.shape,
dtype=input_.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
grad_weight = torch.empty(
weight.main_grad.shape,
dtype=input_.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
weight.grad_added_to_main_grad = True
else:
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel:
handle.wait()
# Need to return None's as gradient has to flow for all the input arguments
# provided during forward
return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None
if ctx.allreduce_dgrad:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None, None
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce"""
@staticmethod
@custom_fwd
def forward(
ctx,
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel,
grad_output_buffer,
shared_expert,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel
ctx.grad_output_buffer = grad_output_buffer
ctx.shared_expert = shared_expert
if sequence_parallel:
if shared_expert:
from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
ag_shared_experts_inputs = AG_SHARED_EXPERTS_INPUTS.pop(0)
if isinstance(ag_shared_experts_inputs, tuple):
ag_shared_experts_inputs, handle = ag_shared_experts_inputs
handle.wait()
total_input = ag_shared_experts_inputs
else:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group()
)
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_output_buffer = ctx.grad_output_buffer
wgrad_compute = True
if grad_output_buffer is not None:
grad_output_buffer.append(grad_output)
wgrad_compute = False
if wgrad_compute:
if ctx.sequence_parallel and not WeightGradStore.is_decoupleBlock:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(
dim_size, input.dtype, "mpu"
)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel and wgrad_compute and not WeightGradStore.is_decoupleBlock:
handle.wait()
if wgrad_compute and not WeightGradStore.is_decoupleBlock:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, total_input
)
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(
dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(
sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
if WeightGradStore.is_decoupleBlock:
# TODO: remove clone under MLA setting
WeightGradStore.put(
total_input.clone().detach(),
grad_output.clone().detach(),
weight,
ctx.sequence_parallel,
in_row=not ctx.sequence_parallel
)
if hasattr(weight, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
weight.skip_grad_accum = True
grad_weight = None
else:
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
if weight.main_grad.dtype == torch.float32:
from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32
npu_matmul_add_fp32(total_input, grad_output, weight.main_grad)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
if hasattr(weight, 'grad_added_to_main_grad'):
# When overlap_grad_reduce is True, need to ensure that backward hooks
# are all run on the main backprop thread to prevent deadlocks. Setup
# dummy grad_weight tensor to prevent backward hooks from being run
# in a background thread.
if getattr(weight, 'zero_out_wgrad', False):
grad_weight = torch.zeros(
weight.main_grad.shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
grad_weight = torch.empty(
weight.main_grad.shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
weight.grad_added_to_main_grad = True
else:
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel:
handle.wait()
# Need to return None's as gradient has to flow for all the input arguments
# provided during forward
return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
shared_expert: bool = False
) -> torch.Tensor:
args = [
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel,
grad_output_buffer,
shared_expert,
]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
if async_grad_allreduce:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
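# Illustrative sketch (hypothetical helper, for clarity only): the warnings above assume
# CUDA_DEVICE_MAX_CONNECTIONS=1 so that the async all-gather / all-reduce issued in backward
# is scheduled ahead of the dependent matmul. One way to satisfy that assumption is to set
# the variable before the first device context is created in the process, e.g.:
def _ensure_single_connection_queue():
    # Has no effect if a CUDA/NPU context already exists in this process.
    os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")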
class ColumnParallelLinear(torch.nn.Module):
def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
shared_expert: bool = False
):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.embedding_activation_buffer = embedding_activation_buffer
self.grad_output_buffer = grad_output_buffer
self.config = config
self.shared_expert = shared_expert
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition, self.input_size, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight,
init_method,
partition_dim=0,
stride=stride,
expert_parallel=(self.is_expert and self.expert_parallel),
)
setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
else:
self.weight = None
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
config.async_tensor_model_parallel_allreduce and world_size > 1
)
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
self.sequence_parallel = False
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
raise RuntimeError(
"`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
"cannot be enabled at the same time."
)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
self.explicit_expert_comm = self.is_expert and (
self.sequence_parallel or self.expert_parallel
)
# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f'{prefix}_extra_state'
)
)
def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
weight (optional): weight tensor to use, compulsory when
skip_weight_param_allocation is True.
Returns:
- output
- bias
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to ColumnParallelLinear forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
else:
# Check the weight passed in is the correct shape
expected_shape = (self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(
f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected"
)
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context == True:
assert (
self.config.cpu_offloading == False
), "CPU Offloading cannot be enabled while using non-TE modules"
bias = self.bias if not self.skip_bias_add else None
if (
self.async_tensor_model_parallel_allreduce
or self.sequence_parallel
or self.explicit_expert_comm
):
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer.append(input_parallel)
# Matrix multiply.
if not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False
if self.explicit_expert_comm
else self.async_tensor_model_parallel_allreduce,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=self.grad_output_buffer
if self.config.defer_embedding_wgrad_compute
else None,
shared_expert=self.shared_expert
)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 0, bias sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
def set_extra_state(self, state: Any):
""" Extra state is ignored """
def get_extra_state(self) -> None:
""" Keep compatibility with TE state dict. """
return None
class RowParallelLinear(torch.nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
shared_expert: bool = False
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.config = config
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
self.shared_expert = shared_expert
if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size, self.input_size_per_partition, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight,
init_method,
partition_dim=1,
stride=stride,
expert_parallel=(self.is_expert and self.expert_parallel),
)
setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
if bias:
if config.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
else:
self.bias = Parameter(
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
else:
self.register_parameter('bias', None)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
self.explicit_expert_comm = self.is_expert and (
self.sequence_parallel or self.expert_parallel
)
# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f'{prefix}_extra_state'
)
)
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context == True:
assert (
self.config.cpu_offloading == False
), "CPU Offloading cannot be enabled while using non-TE modules"
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if not self.weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel=False,
)
# All-reduce across all the partitions.
if self.explicit_expert_comm or self.shared_expert:
assert self.skip_bias_add
output_ = output_parallel
elif self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = (output_ + self.bias) if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 1, bias not sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
def set_extra_state(self, state: Any):
""" Extra state is ignored """
def get_extra_state(self) -> None:
""" Keep compatibility with TE state dict. """
return None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.moe.moe_utils import permute, unpermute
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.utils import make_viewless_tensor
from megatron.training import get_args
from mindspeed.core.transformer.moe.unpermute_without_activation import UnpermuteWithoutActivation
def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
    # use the 0.7.0 implementation for better performance
num_local_tokens_per_expert = torch.histc(
indices, bins=self.num_experts, min=0, max=self.num_experts
)
ep_size = self.config.expert_model_parallel_size
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_extended_ep_size = ep_size * tp_size
if self.drop_and_pad:
self.capacity = self.probs.size(1)
num_tokens_per_local_expert = torch.full(
(self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long,
device=torch.cuda.current_device()
)
return num_tokens_per_local_expert
elif self.config.moe_expert_capacity_factor is not None:
# Token drop but no pad. A synchronization is needed before the first
# permutation to get the `num_out_tokens` CPU value.
self.num_out_tokens = num_local_tokens_per_expert.sum().to(
torch.device("cpu"), non_blocking=True
)
self.cuda_sync_point = "before_permutation_1"
elif tp_extended_ep_size > 1:
        # Dropless tokens with expert parallelism enabled. A synchronization is needed
        # before the expert-parallel AlltoAll communication to get the `input_splits`
        # and `output_splits` CPU values.
self.cuda_sync_point = "before_ep_alltoall"
else:
        # Dropless tokens without expert parallelism. A synchronization is needed before
        # the token_permutation() function returns to get the `tokens_per_expert` CPU value.
self.cuda_sync_point = "before_finish"
if tp_extended_ep_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (
num_local_tokens_per_expert.reshape(tp_extended_ep_size, self.num_local_experts)
.sum(axis=1)
.to(torch.device("cpu"), non_blocking=True)
.numpy()
)
num_global_tokens_per_expert = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
num_local_tokens_per_expert
).reshape(tp_extended_ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
]
self.output_splits = (
self.num_global_tokens_per_local_expert
.sum(axis=-1)
.to(torch.device("cpu"), non_blocking=True)
.numpy()
)
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
else:
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
-1, self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert
if self.num_local_experts > 1:
# No further synchronization is needed because torch.repeat_interleave() calls stream
# synchronization internally when the `output_size` parameter is not provided.
self.cuda_sync_point = "no_sync"
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
)
return num_tokens_per_local_expert
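# Illustrative sketch (hypothetical helper, for clarity only): a single-process
# illustration of the split arithmetic in preprocess(). `input_splits` counts the tokens
# this rank sends to each (tp-extended) EP rank; `output_splits` counts the tokens it
# receives back once the per-rank counts have been gathered.
def _alltoall_splits_example():
    ep_size, num_local_experts = 2, 2                       # 4 experts in total
    # Tokens this rank routes to experts 0..3.
    num_local_tokens_per_expert = torch.tensor([3, 1, 0, 2])
    input_splits = (
        num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(axis=1)
    )                                                       # tensor([4, 2])
    # Pretend result of the gather: one row of per-expert counts per rank.
    num_global_tokens_per_expert = torch.stack(
        [num_local_tokens_per_expert, torch.tensor([1, 0, 5, 1])]
    )
    # This rank owns experts 0..1, so keep only those columns.
    num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, :num_local_experts]
    output_splits = num_global_tokens_per_local_expert.sum(axis=-1)  # tensor([4, 1])
    return input_splits, output_splits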
def alltoall_token_perm1(
self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
):
self.hidden_shape = hidden_states.shape
self.probs = probs
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert indices.dim() == 2, "Expected 2D tensor for indices"
tokens_per_expert = preprocess(self, indices)
# Flatten the input tensor
# hidden_states: [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Permutation 1: input to AlltoAll input
self.hiddden_shape_before_permute = hidden_states.shape
if self.cuda_sync_point == "before_permutation_1":
torch.cuda.current_stream().synchronize()
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
indices,
num_out_tokens=self.num_out_tokens,
padded_mode=self.drop_and_pad,
)
# Perform expert parallel AlltoAll communication
if self.cuda_sync_point == "before_ep_alltoall":
torch.cuda.current_stream().synchronize()
return permutated_local_input_tokens, tokens_per_expert
def alltoall_token_perm2(self, global_input_tokens):
# Permutation 2: AlltoAll output to expert input if num_local_experts > 1
if self.num_local_experts > 1:
if not self.drop_and_pad:
global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
global_input_tokens, self.global_input_tokens_local_experts_indices
)
else:
global_input_tokens = global_input_tokens.reshape(
self.ep_size, self.num_local_experts, self.capacity, -1
)
global_input_tokens = (
global_input_tokens.transpose(0, 1)
.reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
.contiguous()
)
if self.cuda_sync_point == "before_finish":
torch.cuda.current_stream().synchronize()
return global_input_tokens
def alltoall_token_unperm1(
self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None,
):
"""
Reverse the token permutation to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
# Unpermutation 2: expert output to AlltoAll input
if self.num_local_experts > 1:
if not self.drop_and_pad:
hidden_states = unpermute(
hidden_states,
self.reversed_global_input_permutation_mapping,
)
else:
hidden_states = hidden_states.reshape(
self.num_local_experts, self.ep_size, self.capacity, -1
)
hidden_states = (
hidden_states.transpose(0, 1)
.reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
.contiguous()
)
return hidden_states
def alltoall_token_unperm2(self, permutated_local_input_tokens, probs=None):
# Unpermutation 1: AlltoAll output to output
probs = probs if probs is not None else self.probs
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
probs=probs,
padded_mode=self.drop_and_pad,
restore_shape=self.hiddden_shape_before_permute,
)
# Reshape the output tensor
output = output.view(self.hidden_shape)
output = make_viewless_tensor(
inp=output, requires_grad=output.requires_grad, keep_graph=True
)
return output, None
import torch
from torch.autograd.variable import Variable
from megatron.core.pipeline_parallel import p2p_communication
def detach_tensor(tensor, checkpoint_forward=False):
if checkpoint_forward:
return tensor
if tensor is None:
return None
detached_tensor = tensor.detach()
detached_tensor.requires_grad = True
return detached_tensor
def run_graph_backward(graph, output_tensor_grad=None, keep_graph=False, keep_grad=False):
grad_tensor = output_tensor_grad
if output_tensor_grad is None and graph[1] is not None and graph[1].grad is not None:
grad_tensor = graph[1].grad
Variable._execution_engine.run_backward(
tensors=(graph[0],),
grad_tensors=(grad_tensor,),
keep_graph=False,
create_graph=False,
inputs=tuple(),
allow_unreachable=True,
accumulate_grad=True,
)
if not keep_graph:
graph[0].untyped_storage().resize_(0)
if not keep_grad:
grad_tensor.untyped_storage().resize_(0)
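# Illustrative sketch (hypothetical helper, for clarity only): detach_tensor() and
# run_graph_backward() let a layer's autograd graph be cut into segments that are replayed
# in a custom order. The plain-torch toy below shows the pattern without any Megatron/NPU
# state: backward of the second segment deposits a gradient on the detached boundary,
# which then seeds backward of the first segment.
def _segmented_backward_example():
    x = torch.randn(4, requires_grad=True)
    seg1_out = x * 2.0                           # first graph segment
    boundary = detach_tensor(seg1_out)           # cut point between the segments
    seg2_out = (boundary ** 2).sum()             # second graph segment
    # Backward of segment 2 alone; this populates boundary.grad.
    run_graph_backward((seg2_out, None), output_tensor_grad=torch.ones_like(seg2_out),
                       keep_graph=True, keep_grad=True)
    # Backward of segment 1, seeded by the gradient left on the detached boundary.
    run_graph_backward((seg1_out, boundary), keep_graph=True, keep_grad=True)
    return x.grad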
class NoopLayerGraph:
def __init__(self, layer_input, layer_output, layer, checkpointed=False):
self.layer_input = layer_input
if not checkpointed:
self.unperm2_graph = (layer_output, None)
else:
self.unperm2_graph = (None, None)
self.checkpointed = checkpointed
self.layer = layer
def record_layer_inputs(self, *args):
self.layer_inputs = args
class LayerGraph:
def __init__(self, saved_graph_and_graph_inputs, recompute_needed_tensors, input_splits, output_splits, layer, checkpointed=False):
if not checkpointed:
self.attn_graph = saved_graph_and_graph_inputs[0]
self.pre_mlp_layernorm_graph = saved_graph_and_graph_inputs[1]
self.router_graph = saved_graph_and_graph_inputs[2]
self.perm1_graph = saved_graph_and_graph_inputs[3]
self.perm_a2a_graph = saved_graph_and_graph_inputs[4]
self.perm2_graph = saved_graph_and_graph_inputs[5]
self.grouped_mlp_graph = saved_graph_and_graph_inputs[6]
self.unperm1_graph = saved_graph_and_graph_inputs[7]
self.unperm_a2a_graph = saved_graph_and_graph_inputs[8]
self.unperm2_graph = saved_graph_and_graph_inputs[9]
self.shared_experts_graph = saved_graph_and_graph_inputs[10]
else:
self.unperm2_graph = (None, None)
self.layer_input = saved_graph_and_graph_inputs[-1]
self.recompute_needed_tensors = recompute_needed_tensors
self.input_splits = input_splits
self.output_splits = output_splits
self.checkpointed = checkpointed
self.layer = layer
self.is_moe_layer = hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts')
def record_layer_inputs(self, *args):
self.layer_inputs = args
class P2PCommParams:
tensor_shape = None
config = None
def __init__(self, send_next=False, send_prev=False, recv_next=False, recv_prev=False):
self.send_next = send_next
self.send_prev = send_prev
self.recv_next = recv_next
self.recv_prev = recv_prev
def __str__(self):
return f'send next:{self.send_next} send_prev:{self.send_prev} recv_next:{self.recv_next} recv_prev:{self.recv_prev}'
class P2PCommOutput:
def __init__(self, input_tensor=None, output_tensor_grad=None, fwd_wait_handles=None, bwd_wait_handles=None, input_tensor_grad=None):
self.input_tensor = input_tensor
self.fwd_wait_handles = fwd_wait_handles
self.output_tensor_grad = output_tensor_grad
self.bwd_wait_handles = bwd_wait_handles
self.input_tensor_grad = input_tensor_grad
def is_p2p_comm_needed(pp_comm_params: P2PCommParams):
return pp_comm_params is not None and \
(pp_comm_params.send_next or pp_comm_params.send_prev or pp_comm_params.recv_next or pp_comm_params.recv_prev)
def p2p_comm_helper(comm_params: P2PCommParams, tensor_tosend):
assert not (comm_params.send_next and comm_params.send_prev)
assert not (comm_params.recv_next and comm_params.recv_prev)
tensor_send_next = None
if comm_params.send_next:
tensor_send_next = tensor_tosend
tensor_send_prev = None
if comm_params.send_prev:
tensor_send_prev = tensor_tosend
tensor_recv_prev, tensor_recv_next, p2p_handles = p2p_communication._communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=comm_params.recv_prev,
recv_next=comm_params.recv_next,
tensor_shape=comm_params.tensor_shape,
wait_on_reqs=False,
config=comm_params.config
)
if comm_params.recv_next:
return tensor_recv_next, p2p_handles
elif comm_params.recv_prev:
return tensor_recv_prev, p2p_handles
else:
return None, p2p_handles
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import operator
import queue
from functools import reduce
import torch
import torch_npu
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
from megatron.training import get_args
from mindspeed.ops.gmm import GMMFunction
from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
def gather(input_slice, stream):
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input_slice.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = torch.empty(
dim_size, dtype=input_slice.dtype, device=torch.cuda.current_device(), requires_grad=False
)
handle = None
forward_event = torch.npu.Event()
forward_event.record()
with torch.no_grad():
with torch_npu.npu.stream(stream):
stream.wait_event(forward_event)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input_slice, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
return all_gather_buffer, handle
class WeightGradStore:
cache = []
weight_grad_queue = queue.Queue()
store_grad_cache = []
grad_store = []
gather_stream = None
is_decoupleBlock = False
@classmethod
def put(cls, total_input, grad_output, weight, sequence_parallel, in_row=False):
cls.cache.append((total_input, grad_output, weight, sequence_parallel, in_row))
@classmethod
def flush_chunk_grad(cls):
cls.weight_grad_queue.put(cls.cache)
cls.cache = []
@classmethod
def start_decouple(cls):
cls.is_decoupleBlock = True
@classmethod
def end_decouple(cls):
cls.is_decoupleBlock = False
@classmethod
def overlap_all_gather(cls):
# used for grad_output all gather in RowParallel and input all gather in ColumnParallel.
if len(cls.cache) > 0:
[input_, grad_output_slice, weight, sequence_parallel, in_row] = cls.cache.pop(0)
if not sequence_parallel:
return (input_, grad_output_slice, weight, sequence_parallel, in_row), None
if not in_row:
total_input, handle = gather(input_, cls.gather_stream)
grad_output = grad_output_slice
else:
grad_output, handle = gather(grad_output_slice, cls.gather_stream)
total_input = input_
return [total_input, grad_output, weight, sequence_parallel, in_row], handle
else:
raise Exception("All Gather empty queue.")
@classmethod
def overlap_matmul(cls, grad_store_cache):
total_input, grad_output, weight, sequence_parallel, in_row = grad_store_cache
args = get_args()
if hasattr(weight, 'gmm_weight'):
inputs, group_list, group_list_data_type = total_input
if get_args().gemm_gradient_accumulation_fusion:
npu_groupmatmul_add_fp32(inputs, grad_output, group_list, weight.main_grad)
else:
grad_weight = GMMFunction.builder.load().npu_gmm([inputs.t()], [grad_output], [], group_list, 2, 0)[0]
weight.main_grad.data.add_(grad_weight.view(-1, weight.shape[-1]))
inputs.untyped_storage().resize_(0)
grad_output.untyped_storage().resize_(0)
else:
if len(grad_output.shape) > 2:
grad_output = grad_output.contiguous()
sb = grad_output.shape[0] * grad_output.shape[1]
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(
sb, grad_output.shape[2]
)
total_input = total_input.view(
sb, total_input.shape[2]
)
if get_args().gradient_accumulation_fusion:
import fused_weight_gradient_mlp_cuda
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
grad_weight = grad_output.t().matmul(total_input)
weight.main_grad.data.add_(grad_weight)
total_input.untyped_storage().resize_(0)
grad_output.untyped_storage().resize_(0)
@classmethod
def pop(cls, overlap_arg=None):
if len(cls.cache) == 0:
return
if cls.gather_stream is None:
cls.gather_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
(input_, grad_output_slice, weight, sequence_parallel, in_row), handle = cls.overlap_all_gather()
if not sequence_parallel or get_args().moe_fb_overlap:
grad_output = grad_output_slice
else:
grad_output, handle = gather(grad_output_slice, cls.gather_stream)
cls.store_grad_cache = (input_, grad_output, weight, sequence_parallel, in_row)
while len(cls.cache) > 0:
if handle is not None:
handle.wait()
next_grad_cache, handle = cls.overlap_all_gather()
cls.overlap_matmul(cls.store_grad_cache)
cls.store_grad_cache = next_grad_cache
if handle is not None:
handle.wait()
cls.overlap_matmul(cls.store_grad_cache)
cls.store_grad_cache = None
@classmethod
def pop_single(cls):
if cls.weight_grad_queue.empty():
return
cache_list = cls.weight_grad_queue.get()
assert len(cls.cache) == 0
cls.cache = cache_list
cls.pop()
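# Illustrative sketch (hypothetical helper, for clarity only): conceptually, WeightGradStore
# splits a linear backward into an immediate input-gradient step and a deferred
# weight-gradient step that is flushed later (see pop()). The plain-torch toy below mirrors
# that split; it does not exercise the NPU streams, the all-gather overlap or the main_grad
# accumulation of the real class.
def _deferred_wgrad_example():
    inp = torch.randn(8, 16)                    # [tokens, in_features]
    weight = torch.randn(32, 16)                # [out_features, in_features]
    grad_output = torch.randn(8, 32)            # gradient w.r.t. the output
    deferred = []
    # "Decoupled" backward: compute only the input gradient now, stash the rest.
    grad_input = grad_output.matmul(weight)
    deferred.append((inp, grad_output, weight))
    # ...later, analogous to pop(): replay the stashed weight-gradient matmul.
    total_input, grad_out, _w = deferred.pop(0)
    grad_weight = grad_out.t().matmul(total_input)  # same formula as overlap_matmul
    return grad_input, grad_weight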
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from .fwd import *
from .bwd import *
from .fwdbwd import *
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.core import parallel_state
from megatron.training import get_args
from megatron.core.transformer.moe.moe_utils import permute
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all, async_all_gather, async_reduce_scatter
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.core.transformer.moe.moe_utils import get_prob_backward_need_tensors
from ..modules.weight_grad_store import WeightGradStore
from ..modules.utils import run_graph_backward
def transformer_layer_backward_moe(
layer_output_grad,
layer_graph
):
self = layer_graph
args = get_args()
in_detach_stage = WeightGradStore.is_decoupleBlock
dispached_input, fc1_out, act_out, probs, indices, global_input_tokens_local_experts_indices = self.recompute_needed_tensors
ep_group = parallel_state.get_expert_model_parallel_group()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
if args.moe_tp_extend_ep:
ep_group = parallel_state.get_tensor_and_expert_parallel_group()
if tp_size > 1:
shared_expert_grad = layer_output_grad if layer_output_grad is not None else self.unperm2_graph[1].grad
_, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
shared_expert_grad, parallel_state.get_tensor_model_parallel_group()
)
else:
backward_ag_shared = layer_output_grad if layer_output_grad is not None else self.unperm2_graph[1].grad
backward_ag_shared_handle = None
run_graph_backward(self.unperm2_graph, layer_output_grad, keep_grad=True)
if backward_ag_shared_handle is not None:
backward_ag_shared_handle.wait()
backward_ag_shared_handle = None
if layer_output_grad is not None:
layer_output_grad.untyped_storage().resize_(0)
_, unperm1_out_grad, handle = async_all_to_all(
self.unperm_a2a_graph[1].grad,
self.output_splits,
self.input_splits,
ep_group
)
# overlap alltoall by shared experts backward
if self.shared_experts_graph[0] is not None:
run_graph_backward(self.shared_experts_graph, backward_ag_shared)
if get_args().moe_zero_memory == 'level0' or should_recompute_activation(self.layer.layer_number):
with torch.no_grad():
recompute_act_out = self.layer.mlp.experts.activation_func(fc1_out)
act_out.untyped_storage().resize_(recompute_act_out.untyped_storage().size())
act_out.untyped_storage().copy_(recompute_act_out.untyped_storage())
recompute_act_out.untyped_storage().resize_(0)
handle.wait()
handle = None
    # recompute permutation 1 and overlap its alltoall with backward
if get_args().moe_zero_memory == 'level0':
with torch.no_grad():
input_before_perm1 = self.pre_mlp_layernorm_graph[0]
def recomp_token_permutation1(hidden_states, indices):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
permutated_local_input_tokens, _ = permute(
hidden_states, indices
)
return permutated_local_input_tokens
perm1_out = recomp_token_permutation1(input_before_perm1, indices)
_, perm_a2a_out, perm_a2a_handle = async_all_to_all(
perm1_out,
self.output_splits,
self.input_splits,
ep_group
)
run_graph_backward(self.unperm1_graph, unperm1_out_grad)
WeightGradStore.start_decouple()
    run_graph_backward(self.grouped_mlp_graph, keep_grad=True)  # keep for dw computation
if not in_detach_stage:
WeightGradStore.end_decouple()
    run_graph_backward(self.perm2_graph, keep_graph=True)  # keep for dw computation
if get_args().moe_zero_memory == 'level0':
perm_a2a_handle.wait()
perm_a2a_handle = None
_, perm1_out_grad, handle = async_all_to_all(
self.perm_a2a_graph[1].grad,
self.input_splits,
self.output_splits,
ep_group
)
if get_args().moe_zero_memory == 'level0':
with torch.no_grad():
recompute_fc1_input, _ = permute(perm_a2a_out, global_input_tokens_local_experts_indices)
perm_a2a_out.untyped_storage().resize_(0)
# restore fc1 input for dw computation
dispached_input.untyped_storage().resize_(recompute_fc1_input.untyped_storage().size())
dispached_input.untyped_storage().copy_(recompute_fc1_input.untyped_storage())
recompute_fc1_input.untyped_storage().resize_(0)
# dw computation
if not in_detach_stage:
WeightGradStore.pop()
handle.wait()
handle = None
run_graph_backward(self.perm1_graph, perm1_out_grad)
run_graph_backward(self.router_graph)
run_graph_backward(self.pre_mlp_layernorm_graph)
run_graph_backward(self.attn_graph)
self.recompute_needed_tensors = [None for _ in range(len(self.recompute_needed_tensors))]
return self.layer_input.grad
def transformer_layer_backward_dense(layer_output_grad, layer_graph):
run_graph_backward(layer_graph.unperm2_graph, layer_output_grad)
run_graph_backward(layer_graph.pre_mlp_layernorm_graph)
run_graph_backward(layer_graph.attn_graph)
return layer_graph.layer_input.grad
def transformer_layer_backward_noop(layer_output_grad, layer_graph):
run_graph_backward(layer_graph.unperm2_graph, layer_output_grad, keep_grad=True)
return layer_graph.layer_input.grad
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.training import get_args
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all, async_all_gather, async_reduce_scatter
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
from mindspeed.model.transformer import should_recompute_activation
from ..modules.token_dispatcher import (
alltoall_token_perm1, alltoall_token_perm2,
alltoall_token_unperm1, alltoall_token_unperm2
)
from ..modules.attention import attention_forward
from ..modules.utils import (
detach_tensor,
NoopLayerGraph, LayerGraph,
)
def router_forward(
self,
hidden_states
):
probs, indices = self.mlp.router(hidden_states)
return probs, indices
def transformer_layer_forward_moe(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
checkpoint=False
):
# hidden_states: [s, b, h]
args = get_args()
ep_group = parallel_state.get_expert_model_parallel_group()
if args.moe_tp_extend_ep:
ep_group = parallel_state.get_tensor_and_expert_parallel_group()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_group = parallel_state.get_tensor_model_parallel_group()
use_shared_experts = hasattr(self.mlp, 'shared_experts') and self.mlp.shared_experts is not None
recomp_norm = getattr(args, 'recompute_norm', False)
detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual1 = detached_layer_input
# input_layernorm + AttentionForward
hidden_states = attention_forward(
self, detached_layer_input, residual1,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
recompute_norm=recomp_norm
)
attention_out, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual2 = detached_attention_out
# Layer Norm after attention
if recomp_norm:
self.norm_ckpt2 = CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.norm_ckpt2.checkpoint(self.pre_mlp_layernorm, False, detached_attention_out)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(detached_attention_out)
# MLP.
detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
if tp_size > 1 and use_shared_experts:
# shared experts tp communication
_, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
detached_mlp_input, tp_group, is_use_get_global_memory_buffer=True
)
AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
else:
shared_experts_input, shared_experts_allgather_handle = detached_mlp_input, None
# Router forward.
probs, indices = router_forward(self, detached_mlp_input)
shared_expert_output = None
# Token Perm1 Forward
probs_detached = detach_tensor(probs, checkpoint_forward=checkpoint)
perm1_out, tokens_per_expert = alltoall_token_perm1(self.mlp.token_dispatcher, detached_mlp_input, probs_detached, indices)
if shared_experts_allgather_handle is not None:
# overlap shared experts tp comm by token perm1.
shared_experts_allgather_handle.wait()
# Async Perm A2A.
_, perm_a2a_out, perm_a2a_handle = async_all_to_all(
perm1_out,
self.mlp.token_dispatcher.output_splits,
self.mlp.token_dispatcher.input_splits,
ep_group
)
# Shared Experts Forward.
if use_shared_experts:
shared_expert_output, _ = self.mlp.shared_experts(detached_mlp_input)
if recomp_norm:
self.norm_ckpt2.discard_output()
# overlap perm a2a by shared experts computation.
perm_a2a_handle.wait()
    # perm1_out's tensor storage is not needed by backward, but its backward
    # function is, so free the storage while keeping the tensor object alive.
perm1_out.untyped_storage().resize_(0)
if tp_size > 1 and use_shared_experts:
# tp comm for shared experts
share_experts_graph, shared_expert_output, rs_shared_experts_handle = async_reduce_scatter(
shared_expert_output, tp_group
)
else:
share_experts_graph = shared_expert_output
rs_shared_experts_handle = None
detached_perm_a2a_out = detach_tensor(perm_a2a_out, checkpoint_forward=checkpoint)
# Token Perm2 Forward.
dispached_input = alltoall_token_perm2(self.mlp.token_dispatcher, detached_perm_a2a_out)
perm_a2a_out.untyped_storage().resize_(0)
# Grouped MLP Forward
detached_dispached_input = detach_tensor(dispached_input, checkpoint_forward=checkpoint)
(expert_output, fc1_output, act_out), _ = self.mlp.experts(detached_dispached_input, tokens_per_expert)
if args.moe_zero_memory == 'level0':
dispached_input.untyped_storage().resize_(0)
recompute_needed_tensors = [dispached_input, fc1_output, act_out, probs, indices,
self.mlp.token_dispatcher.global_input_tokens_local_experts_indices]
else:
if should_recompute_activation(self.layer_number):
recompute_needed_tensors = [None, fc1_output, act_out, None, None, None]
else:
recompute_needed_tensors = [None, None, None, None, None, None]
detached_expert_output = detach_tensor(expert_output, checkpoint_forward=checkpoint)
# Token Unperm1 Forward
unperm1_out = alltoall_token_unperm1(self.mlp.token_dispatcher, detached_expert_output, None)
expert_output.untyped_storage().resize_(0)
if rs_shared_experts_handle is not None:
# overlap shared experts tp comm by token perm2 + gmm
rs_shared_experts_handle.wait()
        # share_experts_graph's tensor storage is not needed by backward, but its backward
        # function is, so free the storage while keeping the tensor object alive.
share_experts_graph.untyped_storage().resize_(0)
# Launch Token Unperm2 A2A
_, unperm_a2a_out, unperm_a2a_handle = async_all_to_all(
unperm1_out,
self.mlp.token_dispatcher.input_splits,
self.mlp.token_dispatcher.output_splits,
ep_group
)
unperm_a2a_handle.wait()
    # unperm1_out's tensor storage is not needed by backward, but its backward
    # function is, so free the storage while keeping the tensor object alive.
unperm1_out.untyped_storage().resize_(0)
detached_unperm_a2a_out = detach_tensor(unperm_a2a_out, checkpoint_forward=checkpoint)
route_expert_output, _ = alltoall_token_unperm2(self.mlp.token_dispatcher, detached_unperm_a2a_out)
if use_shared_experts:
detached_shared_expert_output = detach_tensor(shared_expert_output, checkpoint_forward=checkpoint)
mlp_output = route_expert_output + detached_shared_expert_output
shared_expert_output.untyped_storage().resize_(0)
else:
detached_shared_expert_output = None
share_experts_graph = None
mlp_output = route_expert_output
if recomp_norm:
mlp_output.register_hook(self.norm_ckpt2.recompute)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
(mlp_output, None), residual2, self.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
saved_tensors = (
(attention_out, detached_attention_out),
(pre_mlp_layernorm_output, detached_mlp_input),
(probs, probs_detached),
(perm1_out, None), # perm1 graph
(None, detached_perm_a2a_out),
(dispached_input, detached_dispached_input), # perm2 graph
(expert_output, detached_expert_output), # grouped mlp graph
(unperm1_out, None), # unperm1 graph
(None, detached_unperm_a2a_out),
(output, None), # unperm2 graph
(share_experts_graph, detached_shared_expert_output),
detached_layer_input
)
graph = LayerGraph(
saved_tensors, recompute_needed_tensors,
self.mlp.token_dispatcher.input_splits, self.mlp.token_dispatcher.output_splits, self,
checkpointed=checkpoint
)
return output, context, graph
def transformer_layer_forward_dense(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
checkpoint=False
):
# hidden_states: [s, b, h]
args = get_args()
recomp_norm = getattr(args, 'recompute_norm', False)
detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual1 = detached_layer_input
# input_layernorm + AttentionForward
hidden_states = attention_forward(
self, detached_layer_input, residual1,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
recompute_norm=recomp_norm
)
attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual2 = detached_attention_out
if recomp_norm:
self.norm_ckpt2 = CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.norm_ckpt2.checkpoint(self.pre_mlp_layernorm, False, detached_attention_out)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(detached_attention_out)
# MLP.
detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
mlp_output_with_bias = self.mlp(detached_mlp_input)
if recomp_norm:
self.norm_ckpt2.discard_output()
mlp_output_with_bias[0].register_hook(self.norm_ckpt2.recompute)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual2, self.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
saved_tensors = (
(attention_graph, detached_attention_out),
(pre_mlp_layernorm_output, detached_mlp_input),
(None, None),
(None, None),
(None, None),
(None, None),
(None, None),
(None, None),
(None, None),
(output, None),
(None, None),
detached_layer_input
)
graph = LayerGraph(
saved_tensors, [], None, None, self,
checkpointed=checkpoint
)
return output, context, graph
def transformer_layer_forward_noop(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
checkpoint=False
):
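    """Forward pass of a noop layer: detach the input and return a clone of it
    together with a NoopLayerGraph for the backward pass."""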
detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
output = detached_layer_input.clone()
return output, context, NoopLayerGraph(detached_layer_input, output, self, checkpointed=checkpoint)
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from contextlib import nullcontext
import torch
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.training import get_args
from megatron.core.transformer.moe.moe_utils import permute
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all, async_all_gather, async_reduce_scatter
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
from mindspeed.core.transformer.moe.moe_utils import AG_SHARED_EXPERTS_INPUTS
from ..modules.token_dispatcher import (
alltoall_token_perm1, alltoall_token_perm2,
alltoall_token_unperm1, alltoall_token_unperm2
)
from ..modules.weight_grad_store import WeightGradStore
from ..modules.attention import (
attention_forward, set_async_alltoall_inputs, get_async_alltoall_outputs
)
from ..modules.utils import (
detach_tensor, run_graph_backward, LayerGraph, is_p2p_comm_needed,
p2p_comm_helper, P2PCommOutput, P2PCommParams
)
def router_forward(
self,
hidden_states
):
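    """Run the MoE router on the MLP input and return routing probs and expert indices."""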
probs, indices = self.mlp.router(hidden_states)
return probs, indices
def transformer_layer_forward_dense_backward_moe_overlaping(
fwd_layer,
hidden_states,
attention_mask,
bwd_layer_output_grad=None,
bwd_layer_graph: LayerGraph = None,
bwd_unperm_a2a_handle=None,
next_bwd_layer_graph: LayerGraph = None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
pp_comm_params: P2PCommParams = None,
bwd_pp_comm_params: P2PCommParams = None,
checkpoint=False
):
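    """Forward pass of a dense layer overlapped with the backward pass of a
    previously recorded MoE layer, hiding the backward token (un)permutation
    all-to-all behind the forward computation.

    Returns (output, context, graph,
    (next_layer_output_grad, next_bwd_unperm_a2a_handle), P2PCommOutput).
    """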
tp_size = parallel_state.get_tensor_model_parallel_world_size()
if checkpoint:
checkpoint_context = torch.no_grad()
else:
checkpoint_context = nullcontext()
args = get_args()
ep_group = parallel_state.get_expert_model_parallel_group()
if args.moe_tp_extend_ep:
ep_group = parallel_state.get_tensor_and_expert_parallel_group()
recomp_norm = getattr(args, 'recompute_norm', False)
bwd_dispached_input, bwd_fc1_out, bwd_act_out, bwd_probs, bwd_indices, global_input_tokens_local_experts_indices = bwd_layer_graph.recompute_needed_tensors
# Unperm2 Bwd
    # check whether the backward unpermutation alltoall was already launched by the previous bwd layer
if bwd_unperm_a2a_handle is None:
run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad)
# Async Unperm A2A
_, unperm1_out_grad, bwd_unperm_a2a_handle = async_all_to_all(
bwd_layer_graph.unperm_a2a_graph[1].grad,
bwd_layer_graph.output_splits,
bwd_layer_graph.input_splits,
ep_group
)
else:
unperm1_out_grad = bwd_layer_output_grad
if args.moe_zero_memory == 'level0':
with torch.no_grad():
bwd_input_before_perm1 = bwd_layer_graph.pre_mlp_layernorm_graph[0]
def recomp_token_permutation1(hidden_states, indices):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
permutated_local_input_tokens, _ = permute(
hidden_states, indices
)
return permutated_local_input_tokens
bwd_perm1_out = recomp_token_permutation1(bwd_input_before_perm1, bwd_indices)
with checkpoint_context:
# Atten Fwd
detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual1 = detached_layer_input
# input_layernorm + AttentionForward
hidden_states = attention_forward(
fwd_layer, detached_layer_input, residual1,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
recompute_norm=recomp_norm
)
attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states,
checkpoint_forward=checkpoint)
# Residual connection.
residual2 = detached_attention_out
if recomp_norm:
fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False,
detached_attention_out)
else:
pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)
if args.moe_zero_memory == 'level0':
_, bwd_perm_a2a_out, bwd_recomp_perm_a2a_handle = async_all_to_all(
bwd_perm1_out,
bwd_layer_graph.output_splits,
bwd_layer_graph.input_splits,
ep_group,
event=bwd_unperm_a2a_handle,
stream=torch.npu.current_stream()
)
if args.moe_zero_memory == 'level0' or should_recompute_activation(bwd_layer_graph.layer.layer_number):
with torch.no_grad():
recompute_act_out = bwd_layer_graph.layer.mlp.experts.activation_func(bwd_fc1_out)
bwd_act_out.untyped_storage().resize_(recompute_act_out.untyped_storage().size())
bwd_act_out.untyped_storage().copy_(recompute_act_out.untyped_storage())
recompute_act_out.untyped_storage().resize_(0)
bwd_unperm_a2a_handle.wait()
bwd_unperm_a2a_handle = None
run_graph_backward(bwd_layer_graph.unperm1_graph, unperm1_out_grad)
unperm1_out_grad.untyped_storage().resize_(0)
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.grouped_mlp_graph, keep_grad=True) # keep for dw
WeightGradStore.end_decouple()
run_graph_backward(bwd_layer_graph.perm2_graph, keep_graph=True) # keep for dw
if args.moe_zero_memory == 'level0':
with torch.no_grad():
bwd_recomp_perm_a2a_handle.wait()
bwd_recomp_perm_a2a_handle = None
recompute_fc1_input, _ = permute(bwd_perm_a2a_out, global_input_tokens_local_experts_indices)
bwd_perm_a2a_out.untyped_storage().resize_(0)
if tp_size > 1:
shared_expert_grad = bwd_layer_graph.shared_experts_graph[1].grad
_, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
shared_expert_grad, parallel_state.get_tensor_model_parallel_group()
)
else:
backward_ag_shared = bwd_layer_graph.shared_experts_graph[1].grad
backward_ag_shared_handle = None
_, perm1_out_grad, bwd_perm_a2a_handle = async_all_to_all(
bwd_layer_graph.perm_a2a_graph[1].grad,
bwd_layer_graph.input_splits,
bwd_layer_graph.output_splits,
ep_group,
event=backward_ag_shared_handle
)
# Grouped MLP dw computation
with checkpoint_context:
# MLP Forward
detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
mlp_output_with_bias = fwd_layer.mlp(detached_mlp_input)
if recomp_norm:
fwd_layer.norm_ckpt2.discard_output()
mlp_output_with_bias[0].register_hook(fwd_layer.norm_ckpt2.recompute)
bwd_perm_a2a_handle.wait()
bwd_perm_a2a_handle = None
run_graph_backward(bwd_layer_graph.perm1_graph, perm1_out_grad)
perm1_out_grad.untyped_storage().resize_(0)
WeightGradStore.start_decouple()
if backward_ag_shared_handle is not None:
backward_ag_shared_handle.wait()
backward_ag_shared_handle = None
shared_expert_grad.untyped_storage().resize_(0)
run_graph_backward(bwd_layer_graph.shared_experts_graph, backward_ag_shared, keep_grad=True) # dw computation
WeightGradStore.end_decouple()
run_graph_backward(bwd_layer_graph.router_graph)
run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
WeightGradStore.end_decouple()
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)
next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
_, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
next_bwd_layer_graph.unperm_a2a_graph[1].grad,
next_bwd_layer_graph.output_splits,
next_bwd_layer_graph.input_splits,
ep_group
)
with checkpoint_context:
with fwd_layer.bias_dropout_add_exec_handler():
hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
mlp_output_with_bias, residual2, fwd_layer.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
# handle fwd p2p communication
next_iter_input_tensor, fwd_p2p_handles = None, None
fwd_pp_comm_params = pp_comm_params
if is_p2p_comm_needed(fwd_pp_comm_params):
next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)
# handle bwd p2p communication
next_iter_output_tensor_grad, bwd_p2p_handles = None, None
if is_p2p_comm_needed(bwd_pp_comm_params):
next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)
if args.moe_zero_memory == 'level0':
# restore fc1 input for dw computation
bwd_dispached_input.untyped_storage().resize_(recompute_fc1_input.untyped_storage().size())
bwd_dispached_input.untyped_storage().copy_(recompute_fc1_input.untyped_storage())
recompute_fc1_input.untyped_storage().resize_(0)
WeightGradStore.pop()
saved_tensors = (
(attention_graph, detached_attention_out),
(pre_mlp_layernorm_output, detached_mlp_input),
(None, None),
(None, None),
(None, None),
(None, None), # perm2 graph
(None, None), # grouped mlp graph
(None, None), # unperm1 graph
(None, None),
(output, None), # unperm2 graph
(None, None),
detached_layer_input
)
graph = LayerGraph(
saved_tensors, [], None, None, fwd_layer,
checkpointed=checkpoint
)
for tensor in bwd_layer_graph.recompute_needed_tensors:
if tensor is not None:
tensor.untyped_storage().resize_(0)
return (output, context, graph,
(next_layer_output_grad, next_bwd_unperm_a2a_handle),
P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
def transformer_layer_forward_moe_backward_dense_overlaping(
fwd_layer,
hidden_states,
attention_mask,
bwd_layer_output_grad=None,
bwd_layer_graph: LayerGraph = None,
bwd_unperm_a2a_handle=None,
next_bwd_layer_graph: LayerGraph = None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
pp_comm_params: P2PCommParams = None,
bwd_pp_comm_params: P2PCommParams = None,
checkpoint=False
):
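    """Forward pass of an MoE layer overlapped with the backward pass of a
    previously recorded dense layer, hiding the forward token (un)permutation
    all-to-all behind the dense layer's backward computation.

    Returns (output, context, graph,
    (next_layer_output_grad, next_bwd_unperm_a2a_handle), P2PCommOutput).
    """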
args = get_args()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_group = parallel_state.get_tensor_model_parallel_group()
use_shared_experts = hasattr(fwd_layer.mlp, 'shared_experts') and fwd_layer.mlp.shared_experts is not None
if checkpoint:
checkpoint_context = torch.no_grad()
else:
checkpoint_context = nullcontext()
args = get_args()
ep_group = parallel_state.get_expert_model_parallel_group()
if args.moe_tp_extend_ep:
ep_group = parallel_state.get_tensor_and_expert_parallel_group()
recomp_norm = getattr(args, 'recompute_norm', False)
with checkpoint_context:
# Atten Fwd
detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual1 = detached_layer_input
# input_layernorm + AttentionForward
hidden_states = attention_forward(
fwd_layer, detached_layer_input, residual1,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
recompute_norm=recomp_norm
)
attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states)
# Residual connection.
residual2 = detached_attention_out
if recomp_norm:
fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
else:
pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)
# MLP.
detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
probs, indices = router_forward(fwd_layer, detached_mlp_input)
# Token Permutation Forward
probs_detached = detach_tensor(probs, checkpoint_forward=checkpoint)
perm1_out, tokens_per_expert = alltoall_token_perm1(fwd_layer.mlp.token_dispatcher, detached_mlp_input, probs_detached, indices)
_, perm_a2a_out, perm_a2a_handle = async_all_to_all(
perm1_out,
fwd_layer.mlp.token_dispatcher.output_splits,
fwd_layer.mlp.token_dispatcher.input_splits,
ep_group
)
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad, keep_grad=True) # keep for dw
run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
WeightGradStore.end_decouple()
perm_a2a_handle.wait()
perm_a2a_handle = None
# Grouped MLP dw computation
with checkpoint_context:
detached_perm_a2a_out = detach_tensor(perm_a2a_out, checkpoint_forward=checkpoint)
dispached_input = alltoall_token_perm2(fwd_layer.mlp.token_dispatcher, detached_perm_a2a_out)
perm_a2a_out.untyped_storage().resize_(0)
if tp_size > 1 and use_shared_experts:
_, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
detached_mlp_input, tp_group, is_use_get_global_memory_buffer=True
)
AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
else:
shared_experts_input, shared_experts_allgather_handle = detached_mlp_input, None
# Grouped MLP Forward
detached_dispached_input = detach_tensor(dispached_input, checkpoint_forward=checkpoint)
(expert_output, fc1_output, act_out), _ = fwd_layer.mlp.experts(detached_dispached_input, tokens_per_expert)
if args.moe_zero_memory == 'level0':
dispached_input.untyped_storage().resize_(0)
recompute_needed_tensors = [dispached_input, fc1_output, act_out, probs, indices,
fwd_layer.mlp.token_dispatcher.global_input_tokens_local_experts_indices]
else:
if should_recompute_activation(fwd_layer.layer_number):
recompute_needed_tensors = [None, fc1_output, act_out, None, None, None]
else:
recompute_needed_tensors = [None, None, None, None, None, None]
detached_expert_output = detach_tensor(expert_output, checkpoint_forward=checkpoint)
        # Token Unpermutation Forward
unperm1_out = alltoall_token_unperm1(fwd_layer.mlp.token_dispatcher, detached_expert_output, None)
expert_output.untyped_storage().resize_(0)
if shared_experts_allgather_handle is not None:
shared_experts_allgather_handle.wait()
shared_experts_allgather_handle = None
_, unperm_a2a_out, unperm_a2a_handle = async_all_to_all(
unperm1_out,
fwd_layer.mlp.token_dispatcher.input_splits,
fwd_layer.mlp.token_dispatcher.output_splits,
ep_group
)
share_experts_graph = None
if use_shared_experts:
shared_expert_output, _ = fwd_layer.mlp.shared_experts(detached_mlp_input)
if tp_size > 1:
share_experts_graph, shared_expert_output, rs_shared_experts_handle = async_reduce_scatter(
shared_expert_output, tp_group
)
rs_shared_experts_handle.wait()
rs_shared_experts_handle = None
share_experts_graph.untyped_storage().resize_(0)
else:
share_experts_graph = shared_expert_output
if recomp_norm:
fwd_layer.norm_ckpt2.discard_output()
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
WeightGradStore.end_decouple()
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)
unperm_a2a_handle.wait()
unperm_a2a_handle = None
unperm1_out.untyped_storage().resize_(0)
next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
_, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
next_bwd_layer_graph.unperm_a2a_graph[1].grad,
next_bwd_layer_graph.output_splits,
next_bwd_layer_graph.input_splits,
ep_group
)
with checkpoint_context:
detached_unperm_a2a_out = detach_tensor(unperm_a2a_out, checkpoint_forward=checkpoint)
route_expert_output, _ = alltoall_token_unperm2(fwd_layer.mlp.token_dispatcher, detached_unperm_a2a_out)
        if use_shared_experts:
detached_shared_expert_output = detach_tensor(shared_expert_output, checkpoint_forward=checkpoint)
mlp_output = route_expert_output + detached_shared_expert_output
shared_expert_output.untyped_storage().resize_(0)
else:
detached_shared_expert_output = None
mlp_output = route_expert_output
if recomp_norm:
mlp_output.register_hook(fwd_layer.norm_ckpt2.recompute)
with fwd_layer.bias_dropout_add_exec_handler():
hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
(mlp_output, None), residual2, fwd_layer.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
# handle fwd p2p communication
next_iter_input_tensor, fwd_p2p_handles = None, None
fwd_pp_comm_params = pp_comm_params
if is_p2p_comm_needed(fwd_pp_comm_params):
next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)
# handle bwd p2p communication
next_iter_output_tensor_grad, bwd_p2p_handles = None, None
if is_p2p_comm_needed(bwd_pp_comm_params):
next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)
WeightGradStore.pop()
saved_tensors = (
(attention_graph, detached_attention_out),
(pre_mlp_layernorm_output, detached_mlp_input),
(probs, probs_detached),
(perm1_out, None), # perm1 graph
(None, detached_perm_a2a_out),
(dispached_input, detached_dispached_input), # perm2 graph
(expert_output, detached_expert_output), # grouped mlp graph
(unperm1_out, None), # unperm1 graph
(None, detached_unperm_a2a_out),
(output, None), # unperm2 graph
(share_experts_graph, detached_shared_expert_output),
detached_layer_input
)
graph = LayerGraph(
saved_tensors, recompute_needed_tensors,
fwd_layer.mlp.token_dispatcher.input_splits, fwd_layer.mlp.token_dispatcher.output_splits, fwd_layer,
checkpointed=checkpoint
)
for tensor in bwd_layer_graph.recompute_needed_tensors:
if tensor is not None:
tensor.untyped_storage().resize_(0)
return (output, context, graph,
(next_layer_output_grad, next_bwd_unperm_a2a_handle),
P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
def transformer_layer_forward_dense_backward_dense_overlaping(
fwd_layer,
hidden_states,
attention_mask,
bwd_layer_output_grad=None,
bwd_layer_graph: LayerGraph = None,
bwd_unperm_a2a_handle=None,
next_bwd_layer_graph: LayerGraph = None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
pp_comm_params: P2PCommParams = None,
bwd_pp_comm_params: P2PCommParams = None,
checkpoint=False
):
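    """Forward pass of a dense layer overlapped with the backward pass of a
    previously recorded dense layer; the backward is split into dx/dw via
    WeightGradStore so the dw computation can be deferred.

    Returns (output, context, graph,
    (next_layer_output_grad, next_bwd_unperm_a2a_handle), P2PCommOutput).
    """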
if checkpoint:
checkpoint_context = torch.no_grad()
else:
checkpoint_context = nullcontext()
args = get_args()
ep_group = parallel_state.get_expert_model_parallel_group()
if args.moe_tp_extend_ep:
ep_group = parallel_state.get_tensor_and_expert_parallel_group()
recomp_norm = getattr(args, 'recompute_norm', False)
with checkpoint_context:
# Atten Fwd
detached_layer_input = detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual1 = detached_layer_input
# input_layernorm + AttentionForward
hidden_states = attention_forward(
fwd_layer, detached_layer_input, residual1,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
recompute_norm=recomp_norm
)
attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states, checkpoint_forward=checkpoint)
# Residual connection.
residual2 = detached_attention_out
if recomp_norm:
fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
else:
pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)
# MLP.
detached_mlp_input = detach_tensor(pre_mlp_layernorm_output, checkpoint_forward=checkpoint)
mlp_output_with_bias = fwd_layer.mlp(detached_mlp_input)
if recomp_norm:
fwd_layer.norm_ckpt2.discard_output()
mlp_output_with_bias[0].register_hook(fwd_layer.norm_ckpt2.recompute)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with fwd_layer.bias_dropout_add_exec_handler():
hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
mlp_output_with_bias, residual2, fwd_layer.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
# handle fwd p2p communication
next_iter_input_tensor, fwd_p2p_handles = None, None
fwd_pp_comm_params = pp_comm_params
if is_p2p_comm_needed(fwd_pp_comm_params):
next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)
    # Decouple backward into dx (input grad) and dw (weight grad)
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad, keep_grad=True) # keep for dw
run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
WeightGradStore.end_decouple()
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)
next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
_, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
next_bwd_layer_graph.unperm_a2a_graph[1].grad,
next_bwd_layer_graph.output_splits,
next_bwd_layer_graph.input_splits,
ep_group
)
# handle bwd p2p communication
next_iter_output_tensor_grad, bwd_p2p_handles = None, None
if is_p2p_comm_needed(bwd_pp_comm_params):
next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)
WeightGradStore.pop()
saved_tensors = (
(attention_graph, detached_attention_out),
(pre_mlp_layernorm_output, detached_mlp_input),
(None, None),
(None, None), # perm1 graph
(None, None),
(None, None), # perm2 graph
(None, None), # grouped mlp graph
(None, None), # unperm1 graph
(None, None),
(output, None), # unperm2 graph
(None, None),
detached_layer_input
)
graph = LayerGraph(
saved_tensors, [], None, None, fwd_layer,
checkpointed=checkpoint
)
for tensor in bwd_layer_graph.recompute_needed_tensors:
if tensor is not None:
tensor.untyped_storage().resize_(0)
return (output, context, graph,
(next_layer_output_grad, next_bwd_unperm_a2a_handle),
P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
def transformer_layer_forward_moe_backward_moe_overlaping(
fwd_layer,
hidden_states,
attention_mask,
bwd_layer_output_grad=None,
bwd_layer_graph: LayerGraph = None,
bwd_unperm_a2a_handle=None,
next_bwd_layer_graph: LayerGraph = None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
pp_comm_params: P2PCommParams = None,
bwd_pp_comm_params: P2PCommParams = None,
checkpoint=False
):
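    """Forward pass of an MoE layer overlapped with the backward pass of a
    previously recorded MoE layer, interleaving the forward and backward token
    (un)permutation all-to-all, shared-experts tp comm, and dx/dw computation.

    Returns (output, context, graph,
    (next_layer_output_grad, next_bwd_unperm_a2a_handle), P2PCommOutput).
    """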
if checkpoint:
checkpoint_context = torch.no_grad()
else:
checkpoint_context = nullcontext()
args = get_args()
ep_group = parallel_state.get_expert_model_parallel_group()
if args.moe_tp_extend_ep:
ep_group = parallel_state.get_tensor_and_expert_parallel_group()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_group = parallel_state.get_tensor_model_parallel_group()
use_shared_experts = hasattr(fwd_layer.mlp, 'shared_experts') and fwd_layer.mlp.shared_experts is not None
recomp_norm = getattr(args, 'recompute_norm', False)
bwd_dispached_input, bwd_fc1_out, bwd_act_out, bwd_probs, bwd_indices, global_input_tokens_local_experts_indices = bwd_layer_graph.recompute_needed_tensors
a2a_hooked_on_attention = getattr(fwd_layer.self_attention, 'a2a_hooked_on_attention', False)
# Unperm2 Bwd
    # check whether the backward unpermutation alltoall was already launched by the previous bwd layer
if bwd_unperm_a2a_handle is None:
run_graph_backward(bwd_layer_graph.unperm2_graph, bwd_layer_output_grad)
# Async Unperm A2A
if tp_size > 1 and a2a_hooked_on_attention:
set_async_alltoall_inputs(
bwd_layer_graph.unperm_a2a_graph[1].grad,
bwd_layer_graph.output_splits,
bwd_layer_graph.input_splits,
ep_group
)
else:
_, unperm1_out_grad, bwd_unperm_a2a_handle = async_all_to_all(
bwd_layer_graph.unperm_a2a_graph[1].grad,
bwd_layer_graph.output_splits,
bwd_layer_graph.input_splits,
ep_group
)
else:
unperm1_out_grad = bwd_layer_output_grad
if args.moe_zero_memory == 'level0':
with torch.no_grad():
bwd_input_before_perm1 = bwd_layer_graph.pre_mlp_layernorm_graph[0]
def recomp_token_permutation1(hidden_states, indices):
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
permutated_local_input_tokens, _ = permute(
hidden_states, indices
)
return permutated_local_input_tokens
bwd_perm1_out = recomp_token_permutation1(bwd_input_before_perm1, bwd_indices)
with checkpoint_context:
# Residual connection.
detached_layer_input = detach_tensor(hidden_states)
residual1 = detached_layer_input
# input_layernorm + AttentionForward
hidden_states = attention_forward(
fwd_layer, detached_layer_input, residual1,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
recompute_norm=recomp_norm
)
if bwd_unperm_a2a_handle is None and tp_size > 1 and a2a_hooked_on_attention:
unperm1_out_grad, bwd_unperm_a2a_handle = get_async_alltoall_outputs()
attention_graph, detached_attention_out = hidden_states, detach_tensor(hidden_states)
# Residual connection.
residual2 = detached_attention_out
if recomp_norm:
fwd_layer.norm_ckpt2 = CheckpointWithoutOutput()
pre_mlp_layernorm_output = fwd_layer.norm_ckpt2.checkpoint(fwd_layer.pre_mlp_layernorm, False, detached_attention_out)
else:
pre_mlp_layernorm_output = fwd_layer.pre_mlp_layernorm(detached_attention_out)
# MLP.
detached_mlp_input = detach_tensor(pre_mlp_layernorm_output)
probs, indices = router_forward(fwd_layer, detached_mlp_input)
if tp_size > 1 and use_shared_experts:
            # launch tp comm here and wait for the last async comm to finish
_, shared_experts_input, shared_experts_allgather_handle = async_all_gather(
detached_mlp_input, tp_group, event=bwd_unperm_a2a_handle,
stream=torch.npu.current_stream() if bwd_unperm_a2a_handle else None,
is_use_get_global_memory_buffer=True
)
AG_SHARED_EXPERTS_INPUTS.append((shared_experts_input, shared_experts_allgather_handle))
else:
shared_experts_input, shared_experts_allgather_handle = detached_mlp_input, None
# Token Permutation Forward
probs_detached = detach_tensor(probs)
perm1_out, tokens_per_expert = alltoall_token_perm1(fwd_layer.mlp.token_dispatcher, detached_mlp_input, probs_detached, indices)
if args.moe_zero_memory == 'level0' or should_recompute_activation(bwd_layer_graph.layer.layer_number):
with torch.no_grad():
recompute_act_out = bwd_layer_graph.layer.mlp.experts.activation_func(bwd_fc1_out)
bwd_act_out.untyped_storage().resize_(recompute_act_out.untyped_storage().size())
bwd_act_out.untyped_storage().copy_(recompute_act_out.untyped_storage())
recompute_act_out.untyped_storage().resize_(0)
last_comm_handle = shared_experts_allgather_handle if shared_experts_allgather_handle else bwd_unperm_a2a_handle
if args.moe_zero_memory == 'level0':
_, bwd_perm_a2a_out, bwd_recomp_perm_a2a_handle = async_all_to_all(
bwd_perm1_out,
bwd_layer_graph.output_splits,
bwd_layer_graph.input_splits,
ep_group,
event=last_comm_handle,
stream=torch.npu.current_stream() if last_comm_handle else None
)
last_comm_handle = bwd_recomp_perm_a2a_handle
with checkpoint_context:
_, perm_a2a_out, perm_a2a_handle = async_all_to_all(
perm1_out,
fwd_layer.mlp.token_dispatcher.output_splits,
fwd_layer.mlp.token_dispatcher.input_splits,
ep_group,
event=last_comm_handle,
stream=torch.npu.current_stream() if last_comm_handle else None
)
last_comm_handle = perm_a2a_handle
with checkpoint_context:
shared_expert_output = None
if use_shared_experts:
if shared_experts_allgather_handle is not None:
shared_experts_allgather_handle.wait()
shared_experts_allgather_handle = None
shared_expert_output, _ = fwd_layer.mlp.shared_experts(detached_mlp_input)
if tp_size > 1:
                # launch tp comm after perm fwd a2a and wait until the shared experts computation finishes.
share_experts_graph, shared_expert_output, rs_shared_experts_handle = async_reduce_scatter(
shared_expert_output, tp_group, event=last_comm_handle,
stream=torch.npu.current_stream() if last_comm_handle else None
)
last_comm_handle = rs_shared_experts_handle
else:
share_experts_graph = shared_expert_output
rs_shared_experts_handle = None
if recomp_norm:
fwd_layer.norm_ckpt2.discard_output()
bwd_unperm_a2a_handle.wait()
bwd_unperm_a2a_handle = None
run_graph_backward(bwd_layer_graph.unperm1_graph, unperm1_out_grad)
unperm1_out_grad.untyped_storage().resize_(0)
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.grouped_mlp_graph, keep_grad=True) # keep for dw
WeightGradStore.end_decouple()
run_graph_backward(bwd_layer_graph.perm2_graph, keep_graph=True) # keep for dw
perm_a2a_handle.wait()
perm_a2a_handle = None
perm1_out.untyped_storage().resize_(0)
_, perm1_out_grad, bwd_perm_a2a_handle = async_all_to_all(
bwd_layer_graph.perm_a2a_graph[1].grad,
bwd_layer_graph.input_splits,
bwd_layer_graph.output_splits,
ep_group,
event=last_comm_handle,
stream=torch.npu.current_stream() if last_comm_handle else None
)
last_comm_handle = bwd_perm_a2a_handle
# launch shared expert grad allgather here
if tp_size > 1:
_, backward_ag_shared, backward_ag_shared_handle = async_all_gather(
bwd_layer_graph.shared_experts_graph[1].grad, tp_group, event=last_comm_handle,
stream=torch.npu.current_stream() if last_comm_handle else None
)
else:
backward_ag_shared = bwd_layer_graph.shared_experts_graph[1].grad
backward_ag_shared_handle = None
# Grouped MLP dw computation
if args.moe_zero_memory == 'level0':
# restore fc1 input for dw computation
with torch.no_grad():
bwd_recomp_perm_a2a_handle.wait()
bwd_recomp_perm_a2a_handle = None
recompute_fc1_input, _ = permute(bwd_perm_a2a_out, global_input_tokens_local_experts_indices)
bwd_perm_a2a_out.untyped_storage().resize_(0)
bwd_dispached_input.untyped_storage().resize_(recompute_fc1_input.untyped_storage().size())
bwd_dispached_input.untyped_storage().copy_(recompute_fc1_input.untyped_storage())
recompute_fc1_input.untyped_storage().resize_(0)
WeightGradStore.pop()
with checkpoint_context:
detached_perm_a2a_out = detach_tensor(perm_a2a_out)
dispached_input = alltoall_token_perm2(fwd_layer.mlp.token_dispatcher, detached_perm_a2a_out)
perm_a2a_out.untyped_storage().resize_(0)
# Grouped MLP Forward
detached_dispached_input = detach_tensor(dispached_input)
(expert_output, fc1_output, act_out), _ = fwd_layer.mlp.experts(detached_dispached_input, tokens_per_expert)
if args.moe_zero_memory == 'level0':
dispached_input.untyped_storage().resize_(0)
recompute_needed_tensors = [dispached_input, fc1_output, act_out, probs, indices,
fwd_layer.mlp.token_dispatcher.global_input_tokens_local_experts_indices]
else:
if should_recompute_activation(fwd_layer.layer_number):
recompute_needed_tensors = [None, fc1_output, act_out, None, None, None]
else:
recompute_needed_tensors = [None, None, None, None, None, None]
detached_expert_output = detach_tensor(expert_output)
        # Token Unpermutation Forward
unperm1_out = alltoall_token_unperm1(fwd_layer.mlp.token_dispatcher, detached_expert_output, None)
expert_output.untyped_storage().resize_(0)
if rs_shared_experts_handle is not None:
rs_shared_experts_handle.wait()
rs_shared_experts_handle = None
share_experts_graph.untyped_storage().resize_(0)
bwd_perm_a2a_handle.wait()
bwd_perm_a2a_handle = None
if backward_ag_shared_handle is not None:
        # ensure tp comm is not overlapped with alltoall comm
backward_ag_shared_handle.wait()
backward_ag_shared_handle = None
    # move shared experts backward before unperm fwd all2all to avoid tp comm collision.
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.shared_experts_graph, backward_ag_shared, keep_grad=True) # dw computation
WeightGradStore.end_decouple()
with checkpoint_context:
# launch async all2all in the middle of attention graph backward
if tp_size > 1 and a2a_hooked_on_attention:
set_async_alltoall_inputs(
unperm1_out, fwd_layer.mlp.token_dispatcher.input_splits, fwd_layer.mlp.token_dispatcher.output_splits, ep_group
)
else:
_, unperm_a2a_out, unperm_a2a_handle = async_all_to_all(
unperm1_out,
fwd_layer.mlp.token_dispatcher.input_splits,
fwd_layer.mlp.token_dispatcher.output_splits,
ep_group
)
run_graph_backward(bwd_layer_graph.perm1_graph, perm1_out_grad)
perm1_out_grad.untyped_storage().resize_(0)
run_graph_backward(bwd_layer_graph.router_graph)
run_graph_backward(bwd_layer_graph.pre_mlp_layernorm_graph, keep_graph=True)
WeightGradStore.start_decouple()
run_graph_backward(bwd_layer_graph.attn_graph, keep_grad=True)
WeightGradStore.end_decouple()
if tp_size > 1 and a2a_hooked_on_attention:
unperm_a2a_out, unperm_a2a_handle = get_async_alltoall_outputs()
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
run_graph_backward(next_bwd_layer_graph.unperm2_graph, bwd_layer_graph.layer_input.grad, keep_graph=True)
unperm_a2a_handle.wait()
unperm_a2a_handle = None
unperm1_out.untyped_storage().resize_(0)
next_layer_output_grad, next_bwd_unperm_a2a_handle = bwd_layer_graph.layer_input.grad, None
if next_bwd_layer_graph is not None and getattr(next_bwd_layer_graph, 'is_moe_layer', False):
_, next_layer_output_grad, next_bwd_unperm_a2a_handle = async_all_to_all(
next_bwd_layer_graph.unperm_a2a_graph[1].grad,
next_bwd_layer_graph.output_splits,
next_bwd_layer_graph.input_splits,
ep_group
)
with checkpoint_context:
detached_unperm_a2a_out = detach_tensor(unperm_a2a_out)
route_expert_output, _ = alltoall_token_unperm2(fwd_layer.mlp.token_dispatcher, detached_unperm_a2a_out)
        if use_shared_experts:
detached_shared_expert_output = detach_tensor(shared_expert_output)
mlp_output = route_expert_output + detached_shared_expert_output
shared_expert_output.untyped_storage().resize_(0)
else:
detached_shared_expert_output = None
share_experts_graph = None
mlp_output = route_expert_output
if recomp_norm:
mlp_output.register_hook(fwd_layer.norm_ckpt2.recompute)
with fwd_layer.bias_dropout_add_exec_handler():
hidden_states = fwd_layer.mlp_bda(fwd_layer.training, fwd_layer.config.bias_dropout_fusion)(
(mlp_output, None), residual2, fwd_layer.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
# handle fwd p2p communication
next_iter_input_tensor, fwd_p2p_handles = None, None
fwd_pp_comm_params = pp_comm_params
if is_p2p_comm_needed(fwd_pp_comm_params):
next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)
# handle bwd p2p communication
next_iter_output_tensor_grad, bwd_p2p_handles = None, None
if is_p2p_comm_needed(bwd_pp_comm_params):
next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_layer_graph.layer_input.grad)
WeightGradStore.pop()
saved_tensors = (
(attention_graph, detached_attention_out),
(pre_mlp_layernorm_output, detached_mlp_input),
(probs, probs_detached),
(perm1_out, None), # perm1 graph
(None, detached_perm_a2a_out),
(dispached_input, detached_dispached_input), # perm2 graph
(expert_output, detached_expert_output), # grouped mlp graph
(unperm1_out, None), # unperm1 graph
(None, detached_unperm_a2a_out),
(output, None), # unperm2 graph
(share_experts_graph, detached_shared_expert_output),
detached_layer_input
)
graph = LayerGraph(
saved_tensors, recompute_needed_tensors,
fwd_layer.mlp.token_dispatcher.input_splits, fwd_layer.mlp.token_dispatcher.output_splits, fwd_layer,
checkpointed=checkpoint
)
for tensor in bwd_layer_graph.recompute_needed_tensors:
if tensor is not None:
tensor.untyped_storage().resize_(0)
return (output, context, graph,
(next_layer_output_grad, next_bwd_unperm_a2a_handle),
P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_layer_graph.layer_input.grad))
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from typing import List
from contextlib import nullcontext
from megatron.training import get_args
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
from mindspeed.core.transformer.transformer_block import NoopTransformerLayer
from .modules.utils import (
detach_tensor, LayerGraph, P2PCommParams
)
from .transformer_layer import transformer_layer_backward
def transformer_block_forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
):
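    """Forward pass over all layers of the block, recording one LayerGraph per
    layer so the backward pass can later be replayed (and overlapped).

    Returns (hidden_states, layer_graphs).
    """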
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(
inp=hidden_states,
requires_grad=True,
keep_graph=True,
)
rng_context = nullcontext()
fp8_context = nullcontext()
assert not self.config.enable_cuda_graph
layer_graphs = []
args = get_args()
with rng_context and fp8_context:
for l_no, layer in enumerate(self.layers):
checkpoint = False
if self.config.recompute_granularity == 'full' and self.training:
if self.config.recompute_method == 'block':
recompute_skip_num_layers = 0
if self.config.fp8 and not hidden_states.requires_grad:
recompute_skip_num_layers += 1
if (l_no >= recompute_skip_num_layers and l_no < self.config.recompute_num_layers + recompute_skip_num_layers):
checkpoint = True
if self.config.recompute_method == 'uniform':
assert self.config.recompute_num_layers == 1
checkpoint = True
hidden_states, context, saved_graphs = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
checkpoint=checkpoint
)
layer_graphs.append(saved_graphs)
# Final layer norm.
if self.post_process and self.post_layer_norm and self.final_layernorm is not None:
detached_hidden_states = detach_tensor(hidden_states)
layer_graphs[-1].unperm2_graph = (layer_graphs[-1].unperm2_graph[0], detached_hidden_states)
hidden_states = self.final_layernorm(detached_hidden_states)
return (hidden_states, layer_graphs)
def transformer_block_backward(
block_output_grad,
layer_graphs: List[LayerGraph],
):
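    """Run the backward pass over the recorded layer graphs in reverse order and
    return the gradient w.r.t. the block input."""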
    # backward should be called first for the final_layernorm and postprocess grads
layer_output_grad = block_output_grad
while len(layer_graphs) > 0:
layer_graph = layer_graphs.pop(-1)
layer_output_grad = transformer_layer_backward(layer_output_grad, layer_graph)
return layer_output_grad
def transformer_block_forward_backward_overlaping(
fwd_block,
hidden_states,
attention_mask,
bwd_block_output_grad,
bwd_block_graphs: List[LayerGraph],
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
pp_comm_params: P2PCommParams = None,
bwd_pp_comm_params: P2PCommParams = None,
):
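    """Forward pass of fwd_block overlapped with the backward pass described by
    bwd_block_graphs, pairing each forward layer with one backward layer graph;
    pp communication is only issued for the last layer of the schedule.

    Returns ((hidden_states, layer_graphs), bwd_layer_output_grad, pp_comm_output).
    """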
if not fwd_block.pre_process:
# See set_input_tensor()
hidden_states = fwd_block.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(
inp=hidden_states,
requires_grad=True,
keep_graph=True,
)
rng_context = nullcontext()
fp8_context = nullcontext()
assert not fwd_block.config.enable_cuda_graph
fwd_layer_graphs = []
bwd_layer_output_grad = bwd_block_output_grad
bwd_unperm_a2a_handle = None
fwd_hidden_states, fwd_context = hidden_states, context
    with rng_context and fp8_context:
for l_no, fwd_layer in enumerate(fwd_block.layers):
checkpoint = False
if fwd_block.config.recompute_granularity == 'full' and fwd_block.training:
if fwd_block.config.recompute_method == 'block':
recompute_skip_num_layers = 0
if fwd_block.config.fp8 and not hidden_states.requires_grad:
recompute_skip_num_layers += 1
if (l_no >= recompute_skip_num_layers and l_no < fwd_block.config.recompute_num_layers + recompute_skip_num_layers):
checkpoint = True
if fwd_block.config.recompute_method == 'uniform':
assert fwd_block.config.recompute_num_layers == 1
checkpoint = True
bwd_layer_graph = bwd_block_graphs.pop(-1)
cur_p2p_params = pp_comm_params
cur_bwd_p2p_params = bwd_pp_comm_params
if l_no != len(fwd_block.layers) - 1 or len(bwd_block_graphs) > 0:
                # no need to execute pp communication for the intermediate layers
cur_p2p_params = P2PCommParams()
cur_bwd_p2p_params = P2PCommParams()
next_bwd_layer_graph = None
if (len(bwd_block_graphs) > 0 and
not bwd_block_graphs[-1].checkpointed and
l_no != len(fwd_block.layers) - 1 and
not isinstance(fwd_block.layers[l_no + 1], NoopTransformerLayer)
):
next_bwd_layer_graph = bwd_block_graphs[-1]
fwd_hidden_states, fwd_context, fwd_layer_graph, \
(bwd_layer_output_grad, bwd_unperm_a2a_handle), \
pp_comm_output = \
fwd_layer(
fwd_hidden_states,
attention_mask,
bwd_layer_output_grad,
bwd_layer_graph=bwd_layer_graph,
bwd_unperm_a2a_handle=bwd_unperm_a2a_handle,
next_bwd_layer_graph=next_bwd_layer_graph,
context=fwd_context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
pp_comm_params=cur_p2p_params,
bwd_pp_comm_params=cur_bwd_p2p_params,
checkpoint=checkpoint
)
fwd_layer_graphs.append(fwd_layer_graph)
# Final layer norm.
if fwd_block.post_process and fwd_block.post_layer_norm and fwd_block.final_layernorm is not None:
detached_hidden_states = detach_tensor(fwd_hidden_states)
fwd_layer_graphs[-1].unperm2_graph = (fwd_layer_graphs[-1].unperm2_graph[0], detached_hidden_states)
fwd_hidden_states = fwd_block.final_layernorm(detached_hidden_states)
return (fwd_hidden_states, fwd_layer_graphs), bwd_layer_output_grad, pp_comm_output
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from contextlib import nullcontext
import torch
from mindspeed.core.transformer.transformer_block import NoopTransformerLayer
from .modules.utils import (
NoopLayerGraph, LayerGraph, is_p2p_comm_needed,
p2p_comm_helper, P2PCommOutput, P2PCommParams
)
from .overlap_funcs import (
transformer_layer_forward_moe,
transformer_layer_forward_dense,
transformer_layer_forward_noop,
transformer_layer_backward_moe,
transformer_layer_backward_dense,
transformer_layer_backward_noop,
transformer_layer_forward_moe_backward_moe_overlaping,
transformer_layer_forward_dense_backward_moe_overlaping,
transformer_layer_forward_moe_backward_dense_overlaping,
transformer_layer_forward_dense_backward_dense_overlaping,
)
def transformer_layer_forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
use_orig_layer_forward=False,
checkpoint=False
):
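    """Dispatch to the proper layer forward implementation: the original
    Megatron forward (for MTP layers), the noop forward, the MoE forward, or
    the dense forward."""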
if checkpoint:
checkpoint_context = torch.no_grad()
else:
checkpoint_context = nullcontext()
with checkpoint_context:
layer_forward_func = None
if use_orig_layer_forward:
            from mindspeed.core.pipeline_parallel.fb_overlap.megatron_adaptor import get_orig_transformer_layer_forward_func
            # for MTP transformer layer forward
layer_forward_func = get_orig_transformer_layer_forward_func()
return layer_forward_func(
self, hidden_states, attention_mask,
context, context_mask, rotary_pos_emb, inference_params, packed_seq_params
)
elif isinstance(self, NoopTransformerLayer):
layer_forward_func = transformer_layer_forward_noop
elif hasattr(self.mlp, 'experts'):
layer_forward_func = transformer_layer_forward_moe
else:
layer_forward_func = transformer_layer_forward_dense
return layer_forward_func(
self, hidden_states, attention_mask,
context, context_mask, rotary_pos_emb, inference_params, packed_seq_params, checkpoint=checkpoint
)
def transformer_layer_backward(
layer_output_grad,
layer_graph
):
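    """Backward pass of a single layer. If the layer was checkpointed, rerun its
    forward to rebuild the graph, then dispatch to the noop / MoE / dense
    backward implementation."""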
if layer_graph.checkpointed:
with torch.enable_grad():
_, _, restored_layer_graph = transformer_layer_forward(
layer_graph.layer, layer_graph.layer_input, *layer_graph.layer_inputs, checkpoint=False
)
restored_layer_graph.unperm2_graph = (restored_layer_graph.unperm2_graph[0], layer_graph.unperm2_graph[1])
layer_graph = restored_layer_graph
if isinstance(layer_graph, NoopLayerGraph):
return transformer_layer_backward_noop(layer_output_grad, layer_graph)
elif layer_graph.is_moe_layer:
return transformer_layer_backward_moe(layer_output_grad, layer_graph)
else:
return transformer_layer_backward_dense(layer_output_grad, layer_graph)
def transformer_layer_forward_backward_overlaping(
fwd_layer,
hidden_states,
attention_mask,
bwd_layer_output_grad=None,
bwd_layer_graph: LayerGraph = None,
bwd_unperm_a2a_handle=None,
next_bwd_layer_graph: LayerGraph = None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
pp_comm_params: P2PCommParams = None,
bwd_pp_comm_params: P2PCommParams = None,
use_orig_layer_forward=False,
checkpoint=False
):
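    """Entry point for layer-level forward/backward overlapping. Falls back to a
    sequential forward + backward when the forward layer is a noop layer, the
    backward graph is a noop graph, or no backward graph is given; otherwise
    dispatches to one of the four MoE/dense overlapping implementations."""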
if isinstance(fwd_layer, NoopTransformerLayer) or bwd_layer_graph is None or isinstance(bwd_layer_graph, NoopLayerGraph):
        # no fwd & bwd overlapping
if bwd_layer_graph is None:
out = transformer_layer_forward(
fwd_layer, hidden_states, attention_mask, context, context_mask, rotary_pos_emb,
inference_params, packed_seq_params, use_orig_layer_forward, checkpoint=checkpoint
)
if len(out) > 2 and checkpoint:
out[2].record_layer_inputs(
attention_mask, context, context_mask, rotary_pos_emb,
inference_params, packed_seq_params, use_orig_layer_forward
)
return out
else:
output, context, graph = transformer_layer_forward(
fwd_layer, hidden_states, attention_mask, context, context_mask, rotary_pos_emb,
inference_params, packed_seq_params, use_orig_layer_forward, checkpoint=checkpoint
)
# handle fwd p2p communication
next_iter_input_tensor, fwd_p2p_handles = None, None
fwd_pp_comm_params = pp_comm_params
if is_p2p_comm_needed(fwd_pp_comm_params):
next_iter_input_tensor, fwd_p2p_handles = p2p_comm_helper(fwd_pp_comm_params, output)
bwd_input_grad = transformer_layer_backward(bwd_layer_output_grad, bwd_layer_graph)
next_iter_output_tensor_grad, bwd_p2p_handles = None, None
if bwd_input_grad is not None:
# handle bwd p2p communication
if is_p2p_comm_needed(bwd_pp_comm_params):
next_iter_output_tensor_grad, bwd_p2p_handles = p2p_comm_helper(bwd_pp_comm_params, bwd_input_grad)
if checkpoint:
graph.record_layer_inputs(
attention_mask, context, context_mask, rotary_pos_emb,
inference_params, packed_seq_params, use_orig_layer_forward
)
return (output, context, graph,
(bwd_input_grad, None),
P2PCommOutput(next_iter_input_tensor, next_iter_output_tensor_grad, fwd_p2p_handles, bwd_p2p_handles, bwd_input_grad))
else:
fb_overlap_func = None
if hasattr(fwd_layer.mlp, 'experts') and bwd_layer_graph.is_moe_layer:
fb_overlap_func = transformer_layer_forward_moe_backward_moe_overlaping
elif hasattr(fwd_layer.mlp, 'experts') and not bwd_layer_graph.is_moe_layer:
fb_overlap_func = transformer_layer_forward_moe_backward_dense_overlaping
elif not hasattr(fwd_layer.mlp, 'experts') and bwd_layer_graph.is_moe_layer:
fb_overlap_func = transformer_layer_forward_dense_backward_moe_overlaping
elif not hasattr(fwd_layer.mlp, 'experts') and not bwd_layer_graph.is_moe_layer:
fb_overlap_func = transformer_layer_forward_dense_backward_dense_overlaping
else:
raise AssertionError('Check Layer Spec, f&b overlap func is not supported!')
if bwd_layer_graph.checkpointed:
_, _, bwd_layer_graph = transformer_layer_forward(
bwd_layer_graph.layer, bwd_layer_graph.layer_input, *bwd_layer_graph.layer_inputs, checkpoint=False
)
out = fb_overlap_func(
fwd_layer, hidden_states, attention_mask, bwd_layer_output_grad, bwd_layer_graph, bwd_unperm_a2a_handle,
next_bwd_layer_graph, context, context_mask, rotary_pos_emb, inference_params,
packed_seq_params, pp_comm_params, bwd_pp_comm_params, checkpoint=checkpoint
)
if checkpoint:
out[2].record_layer_inputs(
attention_mask, context, context_mask, rotary_pos_emb,
inference_params, packed_seq_params, use_orig_layer_forward
)
return out
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import contextlib
from typing import Callable, Iterator, List, Optional, Union
from functools import partial
import torch
from torch.autograd.variable import Variable
from megatron.training import get_timers, get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import (
deallocate_output_tensor,
set_current_microbatch,
check_first_val_step,
clear_embedding_activation_buffer,
finish_embedding_wgrad_compute,
custom_backward
)
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
get_attr_wrapped_model,
get_model_config,
get_model_type,
)
from .gpt_model import gpt_model_backward
from .modules.utils import P2PCommParams
LOSS_BACKWARD_SCALE = torch.tensor(1.0)
# Types
Shape = Union[List[int], torch.Size]
def forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
extra_block_kwargs=None
):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model, extra_block_kwargs)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch, extra_block_kwargs
)
num_tokens = torch.tensor(0, dtype=torch.int)
if parallel_state.is_pipeline_last_stage():
if not collect_non_loss_data:
next_info = None
if isinstance(output_tensor, tuple):
                # pp overlapping is enabled, so the output also carries the model graph
if len(output_tensor) == 2:
output_tensor, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor, model_graph, next_info = output_tensor
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
output_tensor = (output_tensor, model_graph, next_info) if next_info is not None else (
output_tensor, model_graph)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(LOSS_BACKWARD_SCALE)
if config.grad_scale_func is not None
else torch.tensor(1.0)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph=None):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if config.timers is not None:
config.timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
if config.deallocate_pipeline_outputs:
if model_graph is None:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
layer_output_grad = gpt_model_backward(output_tensor_grad[0], model_graph)
else:
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
if model_graph is not None:
input_tensor_grad.append(layer_output_grad)
else:
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
if config.timers is not None:
config.timers('backward-compute').stop()
return input_tensor_grad
def forward_step_vpp_overlap(data_iterator, model, extra_block_kwargs=None):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
"""
from pretrain_gpt import get_batch, loss_func
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers('batch-generator').stop()
if extra_block_kwargs is not None:
        # execute forward-backward overlapping
output_tensor, model_graph, pp_comm_output = \
model(tokens, position_ids, attention_mask, labels=labels, extra_block_kwargs=extra_block_kwargs)
return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
else:
output_tensor, model_graph = model(tokens, position_ids, attention_mask, labels=labels)
return (output_tensor, model_graph), partial(loss_func, loss_mask)
def forward_backward_pipelining_with_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
    # Override the forward step func with forward_step_vpp_overlap.
forward_step_func = forward_step_vpp_overlap
config = get_model_config(model[0])
if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
# Needed only when gradients are finalized in M-Core
if config.finalize_model_grads_func is not None and not forward_only:
embedding_module = clear_embedding_activation_buffer(config, model)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
# Disable async grad reductions
no_sync_func = config.no_sync_func
if isinstance(no_sync_func, list):
def multi_no_sync():
stack = contextlib.ExitStack()
for model_chunk_no_sync_func in config.no_sync_func:
stack.enter_context(model_chunk_no_sync_func())
return stack
no_sync_func = multi_no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
config.grad_sync_func = [config.grad_sync_func for _ in model]
if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
config.param_sync_func = [config.param_sync_func for _ in model]
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Model chunk IDs with synchronized grads
synchronized_model_chunks = set()
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
model_graphs = [[] for _ in range(len(model))]
logits_inputs = []
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
forward_data_store = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
if num_microbatches % pipeline_parallel_size != 0:
msg = f'number of microbatches ({num_microbatches}) is not divisible by '
msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
msg += 'when using interleaved schedule'
raise RuntimeError(msg)
model_type = get_model_type(model[0])
if model_type == ModelType.encoder_and_decoder:
raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
if decoder_seq_length is not None and decoder_seq_length != seq_length:
raise RuntimeError(
"Interleaving is not supported with a different decoder sequence length."
)
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
total_num_microbatches = num_microbatches * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = total_num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
# all workers, followed by more microbatches after depending on
# stage ID (more forward passes for earlier stages, later stages can
# immediately start with 1F1B).
if num_microbatches == pipeline_parallel_size:
num_warmup_microbatches = total_num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
            # add one more warmup microbatch for 1F1B overlapping
num_warmup_microbatches += 1
num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
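    # Worked example (hypothetical sizes): with pipeline_parallel_size = 4,
    # pipeline_parallel_rank = 1, num_model_chunks = 2 and num_microbatches = 8
    # (so total_num_microbatches = 16), the standard warmup count is
    # (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8; the extra "+ 1" for 1F1B overlap raises it
    # to 9, leaving num_microbatches_remaining = 7 steady-state iterations.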
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing) and
# the rest of micro-batches within a window of micro-batches checkpoint
# all Transformer layers. The window of micro-batches is set by the maximum
# outstanding backpropagations and becomes smaller at later pipeline stages.
    # Please refer to appendix C in https://arxiv.org/pdf/2205.05198.pdf
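    # Worked example (hypothetical values): with num_warmup_microbatches = 5, the
    # window below becomes max_outstanding_backprops = 6; if
    # num_microbatches_with_partial_activation_checkpoints = 2, then within each
    # window of 6 micro-batches the first 2 (k % 6 < 2) checkpoint only partial
    # layers (or skip checkpointing) and the remaining 4 checkpoint all layers.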
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
# Synchronize params for first two model chunks
if config.param_sync_func is not None:
config.param_sync_func[0](model[0].parameters())
config.param_sync_func[1](model[1].parameters())
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
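    # Worked example (hypothetical sizes): with pipeline_parallel_size = 2 and
    # num_model_chunks = 2 the microbatch-group size is 4, so forward microbatches
    # 0..7 map to model chunks 0, 0, 1, 1, 0, 0, 1, 1, while the backward mapping is
    # mirrored (1, 1, 0, 0, ...) because the last chunk is drained first.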
def get_microbatch_id_in_model_chunk(iteration_id, forward):
"""Helper method to get the microbatch_id within model chunk given the iteration number."""
assert forward
iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
iteration_id % pipeline_parallel_size
)
return microbatch_id_in_model_chunk
def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the first for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == 0:
return microbatch_id_in_group % pipeline_parallel_size == 0
else:
return False
def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the last for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == num_microbatch_groups - 1:
return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
else:
return False
def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch,
extra_block_kwargs=None, backward_k=None):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.param_sync_func is not None:
param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
if (
param_sync_microbatch_id < total_num_microbatches
and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
):
param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
if 1 < param_sync_chunk_id < num_model_chunks:
config.param_sync_func[param_sync_chunk_id](
model[param_sync_chunk_id].parameters()
)
# forward step
if parallel_state.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step,
forward_only,
is_first_microbatch_for_model_chunk(microbatch_id),
),
current_microbatch=current_microbatch,
extra_block_kwargs=extra_block_kwargs
)
if isinstance(output_tensor, tuple):
if len(output_tensor) == 2:
output_tensor_, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor_, model_graph, pp_comm_output = output_tensor
if parallel_state.is_pipeline_last_stage():
logits_inputs.append(model_graph.layer_graphs[-1].unperm2_graph[1])
model_graphs[model_chunk_id].append(model_graph)
else:
output_tensor_ = output_tensor
output_tensors[model_chunk_id].append(output_tensor_)
if backward_k is not None:
backward_chunk_id = get_model_chunk_id(backward_k, forward=False)
input_tensors[backward_chunk_id].pop(0)
output_tensors[backward_chunk_id].pop(0)
output_tensor_grads[backward_chunk_id].pop(0)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id, logits_bwd=False):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
if not logits_bwd:
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
model_graph = model_graphs[model_chunk_id].pop(0)
else:
input_tensor = logits_inputs.pop(0)
output_tensor = output_tensors[model_chunk_id][0]
output_tensor_grad = output_tensor_grads[model_chunk_id][0]
model_graph = None
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph
)
# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.grad_sync_func is not None:
grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
grad_sync_microbatch_id
):
grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
enable_grad_sync()
config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()
return input_tensor_grad
def check_pipeline_stage(forward_k, backward_k):
        last_pp_rank = parallel_state.get_pipeline_model_parallel_world_size() - 1
        send_next = not (get_model_chunk_id(forward_k, forward=True) == num_model_chunks - 1
                         and pipeline_parallel_rank == last_pp_rank)
        send_prev = not (get_model_chunk_id(backward_k, forward=False) == 0
                         and pipeline_parallel_rank == 0)
        recv_prev = not (get_model_chunk_id(forward_k + 1, forward=True) == 0
                         and pipeline_parallel_rank == 0)
        if forward_k + 1 >= total_num_microbatches:
            recv_prev = False
        recv_next = not (get_model_chunk_id(backward_k + 1, forward=False) == num_model_chunks - 1
                         and pipeline_parallel_rank == last_pp_rank)
        return (P2PCommParams(send_next=send_next, recv_prev=recv_prev),
                P2PCommParams(send_prev=send_prev, recv_next=recv_next))
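    # For example, when the forward microbatch runs the last model chunk on the last
    # pipeline rank there is no downstream stage, so send_next is False; likewise
    # send_prev is False when the backward microbatch runs chunk 0 on rank 0. The
    # recv flags apply the same checks to the next forward/backward microbatch.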
# Run warmup forward passes.
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
fwd_wait_handles = None
bwd_wait_handles = None
P2PCommParams.tensor_shape = tensor_shape
P2PCommParams.config = config
for k in range(num_warmup_microbatches):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
cur_model_chunk_id = get_model_chunk_id(k, forward=True)
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
output_tensor = forward_step_helper(
k, current_microbatch, checkpoint_activations_microbatch, backward_k=None
)
if isinstance(output_tensor, tuple):
            # using pp overlapping
if len(output_tensor) == 2:
output_tensor, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor, model_graph, pp_comm_output = output_tensor
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (total_num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if not config.overlap_p2p_comm:
if isinstance(output_tensor, tuple):
if len(output_tensor) == 2:
output_tensor, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor, model_graph, pp_comm_output = output_tensor
if parallel_state.is_pipeline_last_stage():
model_graph, logits_input = model_graph
                    logits_inputs.append(logits_input)
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
else:
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
output_tensor_grad,
bwd_wait_handles,
) = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
forward_k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
        # Determining the chunk id from the absolute microbatch id needs no modification here.
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
if config.overlap_p2p_comm:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
extra_block_kwargs = {}
backward_k = k
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
input_tensor_grad = backward_step_helper(backward_k, logits_bwd=True)
assert input_tensor_grad is not None, "logits backward should not be None"
extra_block_kwargs.setdefault('bwd_model_grad', input_tensor_grad)
else:
                # input_tensor_grad is obtained via pipeline-parallel (pp) communication.
output_tensor_grad = output_tensor_grads[backward_model_chunk_id][0]
extra_block_kwargs.setdefault('bwd_model_grad', output_tensor_grad)
fwd_pp_comm_params, bwd_pp_comm_params = check_pipeline_stage(forward_k, backward_k)
extra_block_kwargs.setdefault('bwd_model_graph', model_graphs[backward_model_chunk_id].pop(0))
extra_block_kwargs.setdefault('pp_comm_params', fwd_pp_comm_params)
extra_block_kwargs.setdefault('bwd_pp_comm_params', bwd_pp_comm_params)
output_tensor = forward_step_helper(
forward_k, current_microbatch, checkpoint_activations_microbatch, extra_block_kwargs,
backward_k=backward_k
)
output_tensor, model_graph, pp_comm_output = output_tensor
input_tensor, fwd_wait_handles = pp_comm_output.input_tensor, pp_comm_output.fwd_wait_handles
output_tensor_grad, bwd_wait_handles = pp_comm_output.output_tensor_grad, pp_comm_output.bwd_wait_handles
if fwd_pp_comm_params.recv_prev:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if bwd_pp_comm_params.recv_next:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if config.overlap_p2p_comm and bwd_wait_handles is not None:
for wait_handle in bwd_wait_handles:
wait_handle.wait()
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
chunk_id = get_model_chunk_id(k, False)
parallel_state.set_virtual_pipeline_model_parallel_rank(chunk_id)
if parallel_state.is_pipeline_last_stage():
input_tensor_grad = backward_step_helper(k, logits_bwd=True)
output_tensor_grads[chunk_id].append(input_tensor_grad)
output_tensors[chunk_id].pop(0)
output_tensors[chunk_id].append(None) # dummy output tensors
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (total_num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
)
)
# Launch any remaining grad reductions.
enable_grad_sync()
if config.grad_sync_func is not None:
for model_chunk_id in range(num_model_chunks):
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
if config.finalize_model_grads_func is not None and not forward_only:
# If defer_embedding_wgrad_compute is enabled we need to do the
# weight gradient GEMM's here.
finish_embedding_wgrad_compute(config, embedding_module)
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config.finalize_model_grads_func(
model, total_num_tokens if config.calculate_per_token_loss else None
)
if config.timers is not None:
config.timers('forward-backward').stop()
return forward_data_store
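# A minimal usage sketch (assuming the usual Megatron-LM entry point; the
# forward_step_func, data_iterator and model_chunks names are placeholders):
#
#     from megatron.core.pipeline_parallel import get_forward_backward_func
#
#     forward_backward_func = get_forward_backward_func()
#     losses_reduced = forward_backward_func(
#         forward_step_func=forward_step_func,
#         data_iterator=data_iterator,
#         model=model_chunks,              # list of model chunks for interleaving
#         num_microbatches=num_microbatches,
#         seq_length=seq_length,
#         micro_batch_size=micro_batch_size,
#         forward_only=False,
#     )
#
# With virtual pipeline parallelism enabled, get_forward_backward_func() resolves to
# forward_backward_pipelining_with_interleaving, which this module is intended to
# replace with the overlapped variant above.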
import torch
from functools import wraps
from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
def forward_step_wrapper(fn):
@wraps(fn)
def wrapper(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs,
):
output, num_tokens = fn(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs
)
if not isinstance(input_tensor, list):
# unwrap_output_tensor True
output_tensor = output
else:
output_tensor = output[0]
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output, num_tokens
return wrapper
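# A minimal usage sketch (the patch site is hypothetical; the actual monkey-patching
# happens elsewhere in this repo):
#
#     from megatron.core.pipeline_parallel import schedules
#
#     schedules.forward_step = forward_step_wrapper(schedules.forward_step)
#
# so that the MTP loss scale is refreshed after every forward step, mirroring how the
# MoE auxiliary-loss scale is set inside forward_step above.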
from .layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear,
vocab_parallel_embedding_forward,
vocab_parallel_embedding_init,
)
import os
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional
try:
    import flux
except ImportError:
    raise ImportError("flux is NOT installed")
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
is_torch_min_version,
prepare_input_tensors_for_wgrad_compute
)
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
VocabParallelEmbedding,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.mappings import (
_reduce,
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
from megatron.core.tensor_parallel.layers import (
custom_fwd,
custom_bwd,
dist_all_gather_func,
linear_with_frozen_weight,
linear_with_grad_accumulation_and_async_allreduce
)
from dcu_megatron.core.utils import is_flux_min_version
try:
    import fused_weight_gradient_mlp_cuda
    _grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
def vocab_parallel_embedding_init(
self,
num_embeddings: int,
embedding_dim: int,
*,
init_method: Callable,
reduce_scatter_embeddings: bool = False,
config: ModelParallelConfig,
skip_weight_param_allocation: bool = False
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.reduce_scatter_embeddings = reduce_scatter_embeddings
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
    # Divide the weight matrix along the vocabulary dimension.
(self.vocab_start_index, self.vocab_end_index) = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size,
)
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
self.deterministic_mode = config.deterministic_mode
# Allocate weights and initialize.
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
else:
self.weight = None
@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
"""Forward.
Args:
input_ (torch.Tensor): Input tensor.
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
if self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
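# Worked example for the masking above (hypothetical sizes): with a vocabulary of 8
# and tensor_model_parallel_size = 2, rank 1 owns rows [4, 8). A token id of 2 falls
# outside that range, so it is clamped to local row 0 for the lookup and its output
# row is zeroed afterwards; the cross-rank reduce then restores the embedding
# contributed by rank 0, which does own that token.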
def get_tensor_model_parallel_node_size(group=None):
""" 获取节点数
"""
......
import warnings
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec
from .multi_token_predictor import (
MultiTokenPredicationSubmodules,
MultiTokenPredictor
)
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TENorm
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False):
"""
    Multi Token Prediction Layer Specification.
"""
use_te = use_te & HAVE_TE
mtp_spec = ModuleSpec(
module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules(
embedding=None,
enorm=TENorm if use_te else LNImpl,
hnorm=TENorm if use_te else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te else LNImpl,
output_layer=None,
)
)
return mtp_spec
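# A minimal usage sketch (hypothetical; decoder_layer_spec, config, vocab_size and
# max_sequence_length stand for values the caller already has):
#
#     mtp_spec = get_mtp_spec(decoder_layer_spec, use_te=HAVE_TE)
#     mtp_layer = build_module(
#         mtp_spec, config=config, vocab_size=vocab_size,
#         max_sequence_length=max_sequence_length,
#     )
#
# Only the norms, the eh_proj projection and the shared transformer layer are wired up
# here; embedding and output_layer stay None in the spec because MultiTokenPredictor
# builds them itself and can share their weights with the main model.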
import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
from ...tensor_parallel.random import CheckpointWithoutOutput
@dataclass
class MultiTokenPredicationSubmodules:
embedding: Union[ModuleSpec, type] = None
output_layer: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
final_layernorm: Union[ModuleSpec, type] = None
class MultiTokenPredictor(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredicationSubmodules,
vocab_size: int,
max_sequence_length: int,
layer_number: int = 1,
hidden_dropout: float = None,
pre_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
share_mtp_embedding_and_output_weight=True,
recompute_mtp_norm=False,
recompute_mtp_layer=False,
add_output_layer_bias=False
):
super().__init__(config=config)
self.config = config
self.submodules = submodules
self.layer_number = layer_number
self.hidden_dropout = hidden_dropout
self.hidden_size = self.config.hidden_size
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.position_embedding_type = position_embedding_type
# share with main model
self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
self.recompute_layer_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.add_output_layer_bias = add_output_layer_bias
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.eh_proj = build_module(
self.submodules.eh_proj,
self.hidden_size + self.hidden_size,
self.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='eh',
)
self.transformer_layer = build_module(
self.submodules.transformer_layer,
config=self.config,
)
if self.submodules.final_layernorm:
self.final_layernorm = build_module(
self.submodules.final_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
column_parallel_linear_impl = FluxColumnParallelLinear
else:
column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = column_parallel_linear_impl(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def forward(
self,
hidden_input_ids: Tensor,
embed_input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
embeding_weight: Optional[torch.Tensor] = None,
output_weight: Optional[torch.Tensor] = None,
):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = self.embedding(
input_ids=embed_input_ids,
position_ids=position_ids,
weight=embeding_weight,
)
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
rotary_seq_len = decoder_input.size(0)
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
rotary_seq_len *= self.config.context_parallel_size
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
if self.recompute_layer_norm:
self.enorm_ckpt = CheckpointWithoutOutput()
enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
self.hnorm_ckpt = CheckpointWithoutOutput()
hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
else:
enorm_output = self.enorm(decoder_input)
hnorm_output = self.hnorm(hidden_input_ids)
# [s, b, h] -> [s, b, 2h]
hidden_states = torch.concat(
[hnorm_output,
enorm_output],
dim=-1
)
if self.recompute_layer_norm:
self.enorm_ckpt.discard_output()
self.hnorm_ckpt.discard_output()
hidden_states.register_hook(self.enorm_ckpt.recompute)
hidden_states.register_hook(self.hnorm_ckpt.recompute)
# hidden_states -> [s, b, h]
hidden_states, _ = self.eh_proj(hidden_states)
if self.config.tensor_model_parallel_size > 1:
hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
if self.config.sequence_parallel:
hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)
if self.recompute_mtp_layer:
hidden_states, context = tensor_parallel.checkpoint(
self.transformer_layer,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
None,
None,
rotary_pos_emb,
inference_params,
packed_seq_params,
)
else:
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
# Final layer norm.
if self.final_layernorm is not None:
if self.recompute_layer_norm:
self.finalnorm_ckpt = CheckpointWithoutOutput()
finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
else:
finalnorm_output = self.final_layernorm(hidden_states)
else:
finalnorm_output = hidden_states
logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
if self.recompute_layer_norm:
self.finalnorm_ckpt.discard_output()
logits.register_hook(self.finalnorm_ckpt.recompute)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return hidden_states, loss
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import (
all_gather_last_dim_from_tensor_parallel_region,
scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor
SUPPORTED_ATTN_MASK = [
AttnMaskType.padding,
AttnMaskType.causal,
AttnMaskType.no_mask,
AttnMaskType.padding_causal,
]
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDelayedScaling,
TENorm,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
from megatron.core.transformer.torch_norm import WrappedTorchNorm
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def tie_word_embeddings_state_dict(
sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str
) -> None:
"""tie the embedding of the mtp processing stage in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
word_emb_weight (Tensor): weight of the word embedding.
word_emb_weight_key (str): key of the word embedding in the sharded state dict.
Returns: None, acts in-place
"""
mtp_word_emb_replica_id = (
1, # copy of embedding in pre processing stage
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
assert word_emb_weight_key in sharded_state_dict
del sharded_state_dict[word_emb_weight_key]
sharded_state_dict[word_emb_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=word_emb_weight,
key=word_emb_weight_key,
replica_id=mtp_word_emb_replica_id,
allow_shape_mismatch=True,
)
def tie_output_layer_state_dict(
sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str
) -> None:
"""tie the output layer of the mtp processing stage in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
output_layer_weight (Tensor): weight of the output layer.
output_layer_weight_key (str): key of the output layer in the sharded state dict.
Returns: None, acts in-place
"""
mtp_output_layer_replica_id = (
1, # copy of output layer in post processing stage
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
assert output_layer_weight_key in sharded_state_dict
del sharded_state_dict[output_layer_weight_key]
sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_weight,
key=output_layer_weight_key,
replica_id=mtp_output_layer_replica_id,
allow_shape_mismatch=True,
)
def roll_tensor(tensor, shifts=-1, dims=-1):
"""Roll the tensor input along the given dimension(s).
Inserted elements are set to be 0.0.
"""
rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
rolled_tensor.select(dims, shifts).fill_(0)
return rolled_tensor, rolled_tensor.sum()
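# For example, roll_tensor(torch.tensor([1., 2., 3., 4.])) returns
# (tensor([2., 3., 4., 0.]), tensor(9.)): the entries are shifted one step to the
# left and the vacated last position is zeroed so it does not contribute to the loss.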
class MTPLossLoggingHelper:
"""Helper class for logging MTP losses."""
tracker = {}
@staticmethod
def save_loss_to_tracker(
loss: torch.Tensor,
layer_number: int,
num_layers: int,
reduce_group: torch.distributed.ProcessGroup = None,
avg_group: torch.distributed.ProcessGroup = None,
):
"""Save the mtp loss for logging.
Args:
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
            avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
"""
# Skip mtp loss logging if layer_number is None.
if layer_number is None:
return
tracker = MTPLossLoggingHelper.tracker
if "values" not in tracker:
tracker["values"] = torch.zeros(num_layers, device=loss.device)
tracker["values"][layer_number] += loss.detach()
tracker["reduce_group"] = reduce_group
tracker["avg_group"] = avg_group
def clean_loss_in_tracker():
"""Clear the mtp losses."""
tracker = MTPLossLoggingHelper.tracker
tracker["values"].zero_()
tracker["reduce_group"] = None
tracker["avg_group"] = None
def reduce_loss_in_tracker():
"""Collect and reduce the mtp losses across ranks."""
tracker = MTPLossLoggingHelper.tracker
if "values" not in tracker:
return
values = tracker["values"]
# Reduce mtp losses across ranks.
if tracker.get('reduce_group') is not None:
torch.distributed.all_reduce(values, group=tracker.get('reduce_group'))
if tracker.get('avg_group') is not None:
torch.distributed.all_reduce(
values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.AVG
)
def track_mtp_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None):
"""Track the Multi-Token Prediction (MTP) metrics for logging."""
MTPLossLoggingHelper.reduce_loss_in_tracker()
tracker = MTPLossLoggingHelper.tracker
if "values" not in tracker:
return
mtp_losses = tracker["values"] * loss_scale
mtp_num_layers = mtp_losses.shape[0]
for i in range(mtp_num_layers):
name = f"mtp_{i+1} loss"
loss = mtp_losses[i]
if total_loss_dict is not None:
total_loss_dict[name] = loss
if writer is not None:
writer.add_scalar(name, loss, iteration)
if wandb_writer is not None:
wandb_writer.log({f"{name}": loss}, iteration)
MTPLossLoggingHelper.clean_loss_in_tracker()
@dataclass
class MultiTokenPredictionLayerSubmodules:
"""
Dataclass for specifying the submodules of a MultiTokenPrediction module.
Args:
hnorm (Union[ModuleSpec, type]): Specification or instance of the
hidden states normalization to be applied.
enorm (Union[ModuleSpec, type]): Specification or instance of the
embedding normalization to be applied.
eh_proj (Union[ModuleSpec, type]): Specification or instance of the
linear projection to be applied.
transformer_layer (Union[ModuleSpec, type]): Specification
or instance of the transformer block to be applied.
"""
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
layer_norm: Union[ModuleSpec, type] = None
def get_mtp_layer_spec(
transformer_layer_spec: ModuleSpec, use_transformer_engine: bool
) -> ModuleSpec:
"""Get the MTP layer spec.
Returns:
ModuleSpec: Module specification with TE modules
"""
if use_transformer_engine:
assert HAVE_TE, "transformer_engine should be installed if use_transformer_engine is True"
layer_norm_impl = TENorm
column_parallel_linear_impl = TEColumnParallelLinear
else:
layer_norm_impl = LNImpl
column_parallel_linear_impl = ColumnParallelLinear
mtp_layer_spec = ModuleSpec(
module=MultiTokenPredictionLayer,
submodules=MultiTokenPredictionLayerSubmodules(
enorm=layer_norm_impl,
hnorm=layer_norm_impl,
eh_proj=column_parallel_linear_impl,
transformer_layer=transformer_layer_spec,
layer_norm=layer_norm_impl,
),
)
return mtp_layer_spec
def get_mtp_layer_offset(config: TransformerConfig) -> int:
"""Get the offset of the MTP layer."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
return 0
def get_mtp_num_layers_to_build(config: TransformerConfig) -> int:
"""Get the number of MTP layers to build."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
if mpu.is_pipeline_last_stage():
return config.mtp_num_layers if config.mtp_num_layers else 0
else:
return 0
class MTPLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
@staticmethod
def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
"""Preserve the mtp by storing it in the context to avoid garbage collection.
Args:
output (torch.Tensor): The output tensor.
mtp_loss (torch.Tensor): The mtp loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(mtp_loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for mtp loss..
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
gradient.
"""
(mtp_loss,) = ctx.saved_tensors
mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
return grad_output, scaled_mtp_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the mtp loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
matches the scale of the main_loss.
"""
MTPLossAutoScaler.main_loss_backward_scale = scale
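# A minimal usage sketch (following the same trick as the MoE auxiliary loss;
# hidden_states and mtp_loss stand for the values produced in the MTP forward pass):
#
#     hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss)
#
# The main-model backward then also triggers the MTP loss backward, scaled by whatever
# was last passed to MTPLossAutoScaler.set_loss_scale() (e.g. by forward_step_wrapper
# earlier in this commit).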
class MultiTokenPredictionLayer(MegatronModule):
"""The implementation for Multi-Token Prediction (MTP) which extends
the prediction scope to multiple future tokens at each position.
    This MTP implementation sequentially predicts additional tokens and keeps the complete
causal chain at each prediction depth, by using D sequential modules to predict
D additional tokens.
The k-th MTP module consists of a shared embedding layer, a projection matrix,
a Transformer block, and a shared output head.
For the i-th input token at the (k - 1)-th prediction depth, we first combine
the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the
    Transformer block at the k-th depth to produce the output representation.
    For more information, please refer to the DeepSeek-V3 Technical Report:
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
"""
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredictionLayerSubmodules,
layer_number: int = 1,
):
super().__init__(config=config)
self.sequence_parallel = config.sequence_parallel
self.submodules = submodules
self.layer_number = layer_number
self_attention_spec = self.submodules.transformer_layer.submodules.self_attention
attn_mask_type = self_attention_spec.params.get('attn_mask_type', '')
assert attn_mask_type in SUPPORTED_ATTN_MASK, (
f"Multi-Token Prediction (MTP) is not jet supported with "
+ f"{attn_mask_type} attention mask type."
+ f"The supported attention mask types are {SUPPORTED_ATTN_MASK}."
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
        # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
        # of the i-th token's hidden states and the (i + K)-th token's decoder input,
        # so the input's shape is [s, b, 2*h].
        # The output will be sent to the following transformer layer,
        # so the output's shape should be [s, b, h].
self.eh_proj = build_module(
self.submodules.eh_proj,
self.config.hidden_size * 2,
self.config.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
)
self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config)
self.final_layernorm = build_module(
self.submodules.layer_norm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
def forward(
self,
decoder_input: Tensor,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
):
"""
Perform the forward pass through the MTP layer.
Args:
hidden_states (Tensor): hidden states tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
                At the (k - 1)-th MTP module, the i-th element of decoder input is
                the embedding of the (i + K)-th token.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask for cross-attention context
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
            attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
Used as an alternative to apply attention mask for TE cuDNN attention.
inference_params (InferenceParams, optional): Parameters for inference-time
optimizations.
packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
processing.
Returns:
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
assert context is None, f"multi token prediction + cross attention is not yet supported."
assert (
packed_seq_params is None
), f"multi token prediction + sequence packing is not yet supported."
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(
with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red
)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context, fp8_context:
decoder_input = self.enorm(decoder_input)
decoder_input = make_viewless_tensor(
inp=decoder_input, requires_grad=True, keep_graph=True
)
hidden_states = self.hnorm(hidden_states)
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True
)
            # At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
            # and the (i + K)-th token's embedding, and combine them with the linear projection.
hidden_states = torch.cat((decoder_input, hidden_states), -1)
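# hidden_states is now [s, b, 2h]; eh_proj combines the two halves back to the hidden size
# (its output appears to be a last-dim shard under tensor parallelism, hence the all-gather below).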
hidden_states, _ = self.eh_proj(hidden_states)
# For tensor parallel, all gather after linear_fc.
hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states)
# For sequence parallel, scatter after linear_fc and before transformer layer.
if self.sequence_parallel:
hidden_states = scatter_to_sequence_parallel_region(hidden_states)
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
# Layer norm before shared head layer.
hidden_states = self.final_layernorm(hidden_states)
# TENorm produces a "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
return hidden_states
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the multi token prediction layer.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the multi
token prediction layer.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
return sharded_state_dict
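# Illustrative sketch (not used by this module): ignoring tensor/sequence parallelism and fp8,
# the forward above reduces to a norm-concat-project-decode pipeline. The callables passed in
# mirror the submodules used above; note that eh_proj is assumed here to return a plain tensor,
# whereas the parallel linear above returns an (output, bias) tuple.
def _mtp_layer_reference_sketch(decoder_input, hidden_states, attention_mask,
                                enorm, hnorm, eh_proj, transformer_layer, final_layernorm):
    # Normalize the (i + K)-th token embeddings and the previous depth's hidden states.
    decoder_input = enorm(decoder_input)    # [s, b, h]
    hidden_states = hnorm(hidden_states)    # [s, b, h]
    # Concatenate and project the combination back to the hidden size.
    combined = torch.cat((decoder_input, hidden_states), dim=-1)    # [s, b, 2h]
    combined = eh_proj(combined)                                    # [s, b, h]
    # Shared transformer layer, then the final norm before the shared output head.
    output, _ = transformer_layer(hidden_states=combined, attention_mask=attention_mask)
    return final_layernorm(output)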
@dataclass
class MultiTokenPredictionBlockSubmodules:
"""
Dataclass for specifying the submodules of a multi token prediction block.
This class defines the structure for configuring the layers, allowing for
flexible and customizable architecture designs.
Args:
layer_specs (List[ModuleSpec], optional): A list of module specifications for
the layers within the multi token prediction block. Each specification typically
defines a complete multi token prediction layer (e.g., shared embedding,
projection matrix, transformer block, shared output head).
"""
layer_specs: List[ModuleSpec] = None
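# Hedged example: a block submodules spec for two MTP depths could simply repeat a single
# layer spec. `mtp_layer_spec` is a placeholder for however the caller builds a
# MultiTokenPredictionLayer ModuleSpec; it is not defined in this file.
def _build_mtp_block_submodules_example(mtp_layer_spec: ModuleSpec, num_mtp_layers: int = 2):
    return MultiTokenPredictionBlockSubmodules(layer_specs=[mtp_layer_spec] * num_mtp_layers)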
def _get_mtp_block_submodules(
config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]
) -> MultiTokenPredictionBlockSubmodules:
"""
Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.
Args:
config (TransformerConfig): Configuration object for the transformer model.
spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the
multi token prediction block submodules.
Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec.
Returns:
MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block.
"""
# Transformer block submodules.
if isinstance(spec, MultiTokenPredictionBlockSubmodules):
return spec
elif isinstance(spec, ModuleSpec):
if issubclass(spec.module, MultiTokenPredictionBlock):
return spec.submodules
else:
raise Exception(f"specialize for {spec.module.__name__}.")
else:
raise Exception(f"specialize for {type(spec).__name__}.")
class MultiTokenPredictionBlock(MegatronModule):
"""The implementation for Multi-Token Prediction (MTP) which extends
the prediction scope to multiple future tokens at each position.
This MTP implementation sequentially predicts additional tokens and keeps the complete
causal chain at each prediction depth, using D sequential modules to predict
D additional tokens.
The k-th MTP module consists of a shared embedding layer, a projection matrix,
a Transformer block, and a shared output head.
For the i-th input token at the (k - 1)-th prediction depth, we first combine
the representation of the i-th token and the embedding of the (i + K)-th token with
a linear projection. The combined representation serves as the input of the Transformer
block at the k-th depth to produce the output representation.
For more information, please refer to the DeepSeek-V3 Technical Report:
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
"""
def __init__(
self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]
):
super().__init__(config=config)
self.submodules = _get_mtp_block_submodules(config, spec)
self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
self._build_layers()
assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer."
def _build_layers(self):
def build_layer(layer_spec, layer_number):
return build_module(layer_spec, config=self.config, layer_number=layer_number)
self.layers = torch.nn.ModuleList(
[
build_layer(layer_spec, i + 1)
for i, layer_spec in enumerate(self.submodules.layer_specs)
]
)
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
hidden_states: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
embedding=None,
output_layer=None,
output_weight: Optional[torch.Tensor] = None,
compute_language_model_loss=None,
) -> Tensor:
"""
Perform the forward pass through all of the MTP modules.
Args:
hidden_states (Tensor): Hidden states for input token with the shape [s, b, h]
where s is the sequence length, b is the batch size, and h is the hidden size.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
Returns:
(Tensor): The main model hidden states of shape [s, b, h], with the scaled MTP loss
    attached via MTPLossAutoScaler so that it contributes to the backward pass.
"""
assert (
labels is not None
), "labels should not be None for calculating multi token prediction loss."
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(labels)
hidden_states_main_model = hidden_states
for layer_number in range(len(self.layers)):
# Calc logits for the current Multi-Token Prediction (MTP) layers.
input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
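# Each iteration rolls input_ids left by one more position, so the embedding below
# corresponds to the (i + k)-th token at the k-th MTP depth.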
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
# norm, linear projection and transformer
hidden_states = self.layers[layer_number](
decoder_input=decoder_input,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
# output
mtp_logits, _ = output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
# Calc loss for the current Multi-Token Prediction (MTP) layers.
labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
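# labels and loss_mask are rolled in lockstep with input_ids; num_tokens is the number of
# positions still covered by the shifted loss mask and is used to normalize the loss.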
mtp_loss = compute_language_model_loss(labels, mtp_logits)
mtp_loss = loss_mask * mtp_loss
if self.training:
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_tensor_and_context_parallel_group(),
)
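# The MTP loss is scaled by mtp_loss_scaling_factor / mtp_num_layers and attached to the
# main model's hidden states via MTPLossAutoScaler, so it only contributes a gradient in
# backward and leaves the forward activations unchanged.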
mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states_main_model = MTPLossAutoScaler.apply(
hidden_states_main_model, mtp_loss_scale * mtp_loss
)
else:
hidden_states_main_model = MTPLossAutoScaler.apply(
hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens
)
return hidden_states_main_model
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the multi token prediction module.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the multi
token prediction module.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
layer_prefix = f'{prefix}layers.'
for layer in self.layers:
offset = get_mtp_layer_offset(self.config)
sharded_prefix = f'{layer_prefix}{layer.layer_number - 1}.'
state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.'
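# The layer's keys are produced under the local (offset-adjusted) index and then remapped to
# the global MTP layer index, keeping checkpoints consistent across pipeline-parallel layouts.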
sharded_pp_offset = []
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
return sharded_state_dict
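# Hedged usage sketch (not part of this module): a GPT model forward would typically drive the
# MTP block with the main decoder's output hidden states plus callables for the shared
# embedding, output layer, and loss computation. The attribute and method names on `model`
# below (mtp_block, embedding, output_layer, shared_embedding_or_output_weight,
# compute_language_model_loss) are assumptions for illustration only.
def _run_mtp_block_example(model, input_ids, position_ids, attention_mask,
                           hidden_states, labels, loss_mask):
    return model.mtp_block(
        input_ids=input_ids,
        position_ids=position_ids,
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        labels=labels,
        loss_mask=loss_mask,
        embedding=model.embedding,
        output_layer=model.output_layer,
        output_weight=model.shared_embedding_or_output_weight(),
        compute_language_model_loss=model.compute_language_model_loss,
    )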
@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
     # mtp requires separate layernorms for the main model and the mtp modules, so the final norm is moved out of the block
     config = args[0] if len(args) > 1 else kwargs['config']
-    if getattr(config, "num_nextn_predict_layers", 0) > 0:
+    if getattr(config, "mtp_num_layers", 0) > 0:
         self.main_final_layernorm = self.final_layernorm
         self.final_layernorm = None
......
-@dataclass
-class ExtraTransformerConfig:
-    ##################
-    # multi-token prediction
-    ##################
-    num_nextn_predict_layers: int = 0
-    """The number of multi-token prediction layers"""
-    mtp_loss_scale: float = 0.3
-    """Multi-token prediction loss scale"""
-    recompute_mtp_norm: bool = False
-    """Whether to recompute mtp normalization"""
-    recompute_mtp_layer: bool = False
-    """Whether to recompute mtp layer"""
+from functools import wraps
+from dataclasses import dataclass
+from megatron.training import get_args
+from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
+
+
+def transformer_config_post_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self):
+        fn(self)
+        args = get_args()
+        """Number of Multi-Token Prediction (MTP) Layers."""
+        self.mtp_num_layers = args.mtp_num_layers
+        """Weighting factor of Multi-Token Prediction (MTP) loss."""
+        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
+        self.flux_transpose_weight = args.flux_transpose_weight
+    return wrapper
+
+
+@dataclass
+class ExtraTransformerConfig:
+    ##################
+    # multi-token prediction
+    ##################
+    mtp_num_layers: Optional[int] = None
+    """Number of Multi-Token Prediction (MTP) Layers."""
+    share_mtp_embedding_and_output_weight: bool = False
+    """share embedding and output weight with mtp layer."""
+    mtp_loss_scaling_factor: Optional[float] = None
+    """Weighting factor of Multi-Token Prediction (MTP) loss."""
 ##################
 # flux
......
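# Hedged usage sketch: applying the post-init wrapper above would follow the usual
# monkey-patch pattern; the exact registration mechanism used by the surrounding code is
# not shown in this diff.
def _apply_transformer_config_post_init_patch_example():
    TransformerConfig.__post_init__ = transformer_config_post_init_wrapper(
        TransformerConfig.__post_init__
    )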