Commit 6a579b17 authored by dongcl

rewrite combined_1f1b

parent e103a256
......@@ -10,7 +10,7 @@ from torch.autograd.variable import Variable
from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
......@@ -20,13 +20,14 @@ from dcu_megatron.core.parallel_state import get_dualpipe_chunk
def make_viewless(e):
"""make_viewless util func"""
"""Make_viewless util func"""
e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
return e
@contextmanager
def stream_acquire_context(stream, event):
"""Stream acquire context"""
event.wait(stream)
try:
yield
......@@ -34,8 +35,29 @@ def stream_acquire_context(stream, event):
event.record(stream)
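# Example (an illustrative sketch, not part of the diff): order work on a side stream
# behind whatever `event` last captured, then re-record the event on that stream so
# the next acquirer waits on this work.
def _example_run_on_stream(stream, event, fn):  # hypothetical helper for illustration
    with stream_acquire_context(stream, event):
        with torch.cuda.stream(stream):
            return fn()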
class FakeScheduleNode:
"""A placeholder node in the computation graph that simply passes through inputs and outputs.
This class is used as a no-op node in the scheduling system when a real computation node
is not needed but the interface must be maintained. It simply returns its inputs unchanged
in both forward and backward passes.
"""
def forward(self, inputs):
"""Passes through inputs unchanged in the forward pass."""
return inputs
def backward(self, outgrads):
"""Passes through gradients unchanged in the backward pass."""
return outgrads
class ScheduleNode:
"""base node for fine-grained schedule"""
"""Base node for fine-grained scheduling.
This class represents a computational node in the pipeline schedule.
It handles the execution of forward and backward operations on a stream.
"""
def __init__(
self,
......@@ -43,24 +65,30 @@ class ScheduleNode:
stream,
event,
backward_func=None,
free_inputs=False,
memory_strategy=None,
name="schedule_node",
):
"""Initialize a schedule node.
Args:
forward_func (callable): Function to execute during forward pass
stream (torch.cuda.Stream): CUDA stream for computation
event (torch.cuda.Event): Event for synchronization
backward_func (callable, optional): Function for backward pass
memory_strategy (MemoryManagementStrategy, optional): Strategy for memory management
name (str): Name of the node for debugging
"""
self.name = name
self.forward_func = forward_func
self.backward_func = backward_func
self.backward_func = backward_func if backward_func else self.default_backward_func
self.stream = stream
self.event = event
self.free_inputs = free_inputs
self.memory_strategy = memory_strategy or NoOpMemoryStrategy()
self.inputs = None
self.outputs = None
def default_backward_func(self, outputs, output_grad):
"""Default backward function"""
# Handle scalar output
if output_grad is None:
assert outputs.numel() == 1, "implicit grad requires scalar output."
output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
Variable._execution_engine.run_backward(
tensors=outputs,
grad_tensors=output_grad,
......@@ -72,20 +100,16 @@ class ScheduleNode:
)
return output_grad
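# For reference, the execution-engine call above is roughly the public-API equivalent of
# (a sketch; keyword arguments such as keep_graph/accumulate_grad are elided here):
#   torch.autograd.backward(outputs, grad_tensors=output_grad)
# with the scalar case handled by the implicit torch.ones_like gradient built above.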
def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
"""schedule node forward"""
def forward(self, inputs=()):
"""Schedule node forward"""
if not isinstance(inputs, tuple):
inputs = (inputs,)
return self._forward(*inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
return self._forward(*inputs)
def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
def _forward(self, *inputs):
with stream_acquire_context(self.stream, self.event):
torch.cuda.nvtx.range_push(f"{self.name} forward")
with torch.cuda.stream(self.stream):
if stream_wait_event is not None:
stream_wait_event.wait(self.stream)
self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
for i, input in enumerate(self.inputs):
if input is not None:
......@@ -100,50 +124,35 @@ class ScheduleNode:
data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
self.output = data
if stream_record_event is not None:
stream_record_event.record(self.stream)
torch.cuda.nvtx.range_pop()
if self.free_inputs:
for input in inputs:
input.record_stream(self.stream)
input.untyped_storage().resize_(0)
# Handle inputs using the memory strategy
self.memory_strategy.handle_inputs(inputs, self.stream)
return self.output
def get_output(self):
"""get the forward output"""
"""Get the forward output"""
return self.output
def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
"""schedule node backward"""
def backward(self, output_grad):
"""Schedule node backward"""
if not isinstance(output_grad, tuple):
output_grad = (output_grad,)
return self._backward(*output_grad, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
return self._backward(*output_grad)
def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
def _backward(self, *output_grad):
with stream_acquire_context(self.stream, self.event):
torch.cuda.nvtx.range_push(f"{self.name} backward")
with torch.cuda.stream(self.stream):
if stream_wait_event is not None:
stream_wait_event.wait(self.stream)
outputs = self.output
if not isinstance(outputs, tuple):
outputs = (outputs,)
assert len(outputs) == len(
output_grad
), f"{len(outputs)} of {type(outputs[0])} vs {len(output_grad)} of {type(output_grad[0])}"
if self.backward_func is not None:
output_grad = self.backward_func(outputs, output_grad)
else:
output_grad = self.default_backward_func(outputs, output_grad)
if stream_record_event is not None:
stream_record_event.record(self.stream)
assert len(outputs) == len(output_grad), (
f"{len(outputs)} of {type(outputs[0])} is not equal to "
f"{len(output_grad)} of {type(output_grad[0])}"
)
output_grad = self.backward_func(outputs, output_grad)
torch.cuda.nvtx.range_pop()
# output_grad may be from another stream
......@@ -153,7 +162,7 @@ class ScheduleNode:
return self.get_grad()
def get_grad(self):
"""get the grad of inputs"""
"""Get the grad of inputs"""
grad = tuple([e.grad if e is not None else None for e in self.inputs])
# clear state
self.inputs = None
......@@ -165,7 +174,7 @@ class ScheduleNode:
class AbstractSchedulePlan(ABC):
"""to use combined 1f1b, model must implement build_schedule_plan while take the same
"""To use combined 1f1b, model must implement build_schedule_plan while take the same
signature as model forward but return an instance of AbstractSchedulePlan"""
@classmethod
......@@ -197,7 +206,29 @@ def schedule_chunk_1f1b(
post_forward=None,
post_backward=None,
):
"""model level 1f1b fine-grained schedule"""
"""Model level 1f1b fine-grained schedule
This function schedules the forward and backward passes for a chunk of the model.
It takes in the forward schedule plan, backward schedule plan, gradient, and optional
context managers for the forward and backward passes.
Args:
f_schedule_plan (subclass of AbstractSchedulePlan): The forward schedule plan
b_schedule_plan (subclass of AbstractSchedulePlan): The backward schedule plan
grad (Tensor or None): The gradient of the loss function
f_context (VppContextManager or None): The VppContextManager for the forward pass
b_context (VppContextManager or None): The VppContextManager for the backward pass
pre_forward (callable or None): The function to call before the forward pass
pre_backward (callable or None): The function to call before the backward pass
post_forward (callable or None): The function to call after the forward pass
post_backward (callable or None): The function to call after the backward pass
Returns:
The output of the forward pass
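Example (illustrative; plan and context names are placeholders):
    output = schedule_chunk_1f1b(f_plan, b_plan, grad, f_context=f_ctx, b_context=b_ctx)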
"""
# Calls fine_grained_schedule.py::ModelChunkSchedulePlan.forward_backward(),
# which calls fine_grained_schedule.py::schedule_chunk_1f1b()
return type(f_schedule_plan or b_schedule_plan).forward_backward(
f_schedule_plan,
b_schedule_plan,
......@@ -216,7 +247,7 @@ _COM_STREAM = None
def set_streams(comp_stream=None, com_stream=None):
"""set the streams for communication and computation"""
"""Set the streams for communication and computation"""
global _COMP_STREAM
global _COM_STREAM
if _COMP_STREAM is not None:
......@@ -234,19 +265,19 @@ def set_streams(comp_stream=None, com_stream=None):
def get_comp_stream():
"""get the stream for computation"""
"""Get the stream for computation"""
global _COMP_STREAM
return _COMP_STREAM
def get_com_stream():
"""get the stream for communication"""
"""Get the stream for communication"""
global _COM_STREAM
return _COM_STREAM
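# Example (an illustrative sketch, not part of the diff): create the two side streams
# once at startup, then fetch them wherever schedule nodes are built.
def _example_init_streams():  # hypothetical helper for illustration
    set_streams(comp_stream=torch.cuda.Stream(), com_stream=torch.cuda.Stream())
    return get_comp_stream(), get_com_stream()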
class VppContextManager:
"""a reusable context manager for switch vpp stage"""
"""A reusable context manager for switch vpp stage"""
def __init__(self, vpp_rank):
self.vpp_rank = vpp_rank
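# Usage sketch (illustrative): one manager per virtual-pipeline stage, re-entered for
# every microbatch scheduled on that stage, e.g.
#   f_context = VppContextManager(vpp_rank)
#   with f_context:
#       ...  # run this chunk's forward (or backward) under the selected vpp stage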
......@@ -353,9 +384,17 @@ def forward_backward_step(
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
assert (
checkpoint_activations_microbatch is None
), "checkpoint_activations_microbatch is not supported for combined_1f1b"
if config.combined_1f1b_recipe != "ep_a2a":
raise NotImplementedError(
f"combined_1f1b_recipe {config.combined_1f1b_recipe} not supported yet"
)
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
if config.timers is not None:
if f_model is not None and config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if config.enable_autocast:
......@@ -364,6 +403,8 @@ def forward_backward_step(
context_manager = contextlib.nullcontext()
# forward preprocess
unwrap_output_tensor = False
f_schedule_plan = None
if f_model is not None:
with f_context:
if is_first_microbatch and hasattr(f_model, 'set_is_first_microbatch'):
......@@ -371,7 +412,6 @@ def forward_backward_step(
if current_microbatch is not None:
set_current_microbatch(f_model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
......@@ -381,20 +421,20 @@ def forward_backward_step(
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, f_model)
f_schedule_plan, loss_func = forward_step_func(data_iterator, f_model)
else:
output_tensor, loss_func = forward_step_func(
f_schedule_plan, loss_func = forward_step_func(
data_iterator, f_model, checkpoint_activations_microbatch
)
assert isinstance(
output_tensor, AbstractSchedulePlan
f_schedule_plan, AbstractSchedulePlan
), "first output of forward_step_func must be one instance of AbstractSchedulePlan"
# backward preprocess
unwrap_input_tensor_grad = False
b_schedule_plan = None
if b_model is not None:
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(b_input_tensor, list):
b_input_tensor = [b_input_tensor]
unwrap_input_tensor_grad = True
......@@ -418,9 +458,8 @@ def forward_backward_step(
torch.autograd.backward(b_output_tensor[0], grad_tensors=b_output_tensor_grad[0])
b_output_tensor_grad[0] = loss_node.get_grad()
f_schedule_plan = output_tensor if f_model else None
grad = b_output_tensor_grad[0] if b_model else None
with context_manager:
with context_manager: # autocast context
# schedule forward and backward
output_tensor = schedule_chunk_1f1b(
f_schedule_plan,
......@@ -436,7 +475,7 @@ def forward_backward_step(
# forward post process
num_tokens = None
if f_model:
if f_model is not None:
with f_context:
model_vp_stage = getattr(f_model, "vp_stage", None)
if vp_stage is not None and model_vp_stage is not None:
......@@ -535,7 +574,18 @@ def forward_backward_step(
return output_tensor, num_tokens, input_tensor_grad
def get_default_cls_for_unwrap():
"""Returns the default classes to unwrap from a model.
This function provides a tuple of classes that should be unwrapped from a model
to access the underlying GPTModel instance. It includes DistributedDataParallel
and Float16Module by default, and also attempts to include LegacyFloat16Module
if available for backward compatibility.
Returns:
tuple: A tuple of classes to unwrap from a model.
"""
cls = (DistributedDataParallel, Float16Module)
try:
# legacy should not be used in core, but for backward compatibility, we support it here
......@@ -547,7 +597,9 @@ def get_default_cls_for_unwrap():
def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
"""unwrap_model DistributedDataParallel and Float16Module wrapped model"""
"""Unwrap_model DistributedDataParallel and Float16Module wrapped model
to return GPTModel instance
"""
return_list = True
if not isinstance(model, list):
model = [model]
......@@ -556,19 +608,80 @@ def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
assert isinstance(
model_module, GPTModel
), "The final unwrapped model must be a GPTModel instance"
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
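# Example (an illustrative sketch, not part of the diff): both a single wrapped module
# and a list of wrapped modules are accepted.
def _example_unwrap(wrapped_model, wrapped_model_list):  # hypothetical names
    gpt = unwrap_model(wrapped_model)            # -> GPTModel
    gpt_list = unwrap_model(wrapped_model_list)  # -> list of GPTModel
    return gpt, gpt_list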
def wrap_forward_func(config, forward_step_func):
"""wrap the input to forward_step_func, to make forward_step_func return schedule plan"""
def wrap_forward_func(forward_step_func):
"""Wrap the input to forward_step_func.
The wrapped function will return forward_schedule_plan and the loss_function.
"""
def wrapped_func(data_iterator, model):
# Model is unwrapped to get GPTModel instance.
# GPTModel.build_schedule_plan(model_forward_inputs) is called in the forward_step.
# The return value becomes (forward_schedule_plan, loss_function),
# replacing the previous (forward_output_tensor, loss_function).
return forward_step_func(data_iterator, unwrap_model(model).build_schedule_plan)
if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a":
return wrapped_func
else:
return forward_step_func
return wrapped_func
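# Example (an illustrative sketch, not part of the diff): the wrapped step hands
# GPTModel.build_schedule_plan to the user's forward_step_func, so the first return
# value is a schedule plan instead of an output tensor.
def _example_wrapped_step(user_forward_step, data_iterator, model):  # hypothetical names
    wrapped = wrap_forward_func(user_forward_step)
    f_schedule_plan, loss_func = wrapped(data_iterator, model)
    return f_schedule_plan, loss_func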
class MemoryManagementStrategy:
"""Base class for memory management strategies.
Different memory management strategies can be implemented by subclassing this class.
These strategies control how tensors are handled in memory during the computation.
"""
def handle_inputs(self, inputs, stream):
"""Process input tensors after computation.
Args:
inputs (tuple): Input tensors that have been used
stream (torch.cuda.Stream): Current CUDA stream
"""
pass
def handle_outputs(self, outputs, stream):
"""Process output tensors after computation.
Args:
outputs (tuple): Output tensors produced by the computation
stream (torch.cuda.Stream): Current CUDA stream
"""
pass
class NoOpMemoryStrategy(MemoryManagementStrategy):
"""Strategy that performs no memory management operations.
This is the default strategy - it doesn't free any memory.
"""
pass
class FreeInputsMemoryStrategy(MemoryManagementStrategy):
"""Strategy that immediately frees input tensors after they are used.
This strategy is useful for nodes where inputs are no longer needed
after computation, helping to reduce memory usage.
"""
def handle_inputs(self, inputs, stream):
"""Free input tensors by resizing their storage to zero.
Args:
inputs (tuple): Input tensors to be freed
stream (torch.cuda.Stream): Current CUDA stream
"""
for input in inputs:
if input is not None:
input.record_stream(stream)
input.untyped_storage().resize_(0)
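# Example wiring (an illustrative sketch, not part of the diff): a node whose inputs
# are freed on its stream right after the forward computation.
def _example_build_node(forward_func):  # hypothetical helper for illustration
    return ScheduleNode(
        forward_func,
        stream=get_comp_stream(),
        event=torch.cuda.Event(),
        memory_strategy=FreeInputsMemoryStrategy(),
        name="example_node",
    )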
......@@ -94,7 +94,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
self.collect_per_batch_state(state)
self.apply_per_batch_state(origin_state)
def meta_prepare(
def dispatch_preprocess(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
):
self.hidden_shape = hidden_states.shape
......@@ -112,9 +112,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
self.routing_map = pad_routing_map(self.routing_map, pad_multiple)
tokens_per_expert = self.preprocess(self.routing_map)
return tokens_per_expert
def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
if self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
......@@ -235,8 +232,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, probs, routing_map)
# Perform expert parallel AlltoAll communication
tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
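# Staged sequence used by the fine-grained schedule (a sketch of the call order;
# the combine_* stages mirror the dispatch_* stages above):
#   dispatch_preprocess -> dispatch_all_to_all -> dispatch_postprocess
#   -> experts(...) -> combine_preprocess -> combine_all_to_all -> combine_postprocess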
......
......@@ -193,7 +193,7 @@ def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[i
class TransformerLayer(MegatronCoreTransformerLayer):
def forward(
def _submodule_attn_router_forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
......@@ -209,275 +209,214 @@ class TransformerLayer(MegatronCoreTransformerLayer):
*,
inference_params: Optional[Any] = None,
):
if (
not isinstance(self.mlp, MoELayer)
or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
):
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
inference_params=inference_params,
)
(
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
) = self._submodule_attention_router_compound_forward(
hidden_states,
attention_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
attention_bias,
inference_context,
packed_seq_params,
sequence_len_offset,
inference_params=inference_params,
)
(tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
)
(expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
tokens_per_expert,
global_input_tokens,
global_probs,
pre_mlp_layernorm_output
)
expert_output = self._submodule_combine_forward(expert_output)[0]
output = self._submodule_post_combine_forward(
expert_output,
shared_expert_output,
mlp_bias,
hidden_states
)
return output, None
def _submodule_attention_forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
# todo
inference_context = deprecate_inference_params(inference_context, inference_params)
# Residual connection.
residual = hidden_states
# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
nvtx_range_push(suffix="self_attention")
attention_output_with_bias = self.self_attention(
input_layernorm_output,
"""
Performs a combined forward pass that includes self-attention and MLP routing logic.
"""
pre_mlp_layernorm_output, residual, context = self._forward_attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_context=inference_context,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
nvtx_range_pop(suffix="self_attention")
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="self_attn_bda")
return hidden_states
def _submodule_attention_router_compound_forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
"""
Performs a combined forward pass that includes self-attention and MLP routing logic.
"""
hidden_states = self._submodule_attention_forward(
hidden_states,
attention_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
attention_bias,
inference_context,
packed_seq_params,
sequence_len_offset,
inference_params=inference_params,
)
# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
pre_mlp_layernorm_output, probs, routing_map
)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
pre_mlp_layernorm_output, routing_map, tokens_per_expert
pre_mlp_layernorm_output, probs, routing_map
)
outputs = [
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
]
return tuple(outputs)
return (tokens_per_expert, permutated_local_input_tokens, permuted_probs, pre_mlp_layernorm_output, residual, context)
def _submodule_attn_router_postprocess(
self,
node,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
pre_mlp_layernorm_output,
residual,
context,
):
node.common_state.residual = node.detach(residual)
if self.mlp.use_shared_expert:
node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
return tokens_per_expert, permutated_local_input_tokens, permuted_probs
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs, state=None):
"""
Dispatches tokens to the appropriate experts based on the router output.
"""
tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
tokens_per_expert, permutated_local_input_tokens, permuted_probs
)
return [tokens_per_expert, global_input_tokens, global_probs]
token_dispatcher = self.mlp.token_dispatcher
tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
def _submodule_dense_forward(self, hidden_states):
residual = hidden_states
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return tokens_per_expert, global_input_tokens, global_probs
return output
def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
return tokens_per_expert, global_input_tokens, global_probs
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, state=None):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
shared_expert_output = None
(dispatched_input, tokens_per_expert, permuted_probs) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
token_dispatcher = self.mlp.token_dispatcher
dispatched_input, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess(
tokens_per_expert, global_input_tokens, global_probs
)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
expert_output, mlp_bias = self.mlp.experts(
dispatched_input, tokens_per_expert, permuted_probs
)
assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
return expert_output, shared_expert_output, mlp_bias
assert state is not None
shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
def _submodule_combine_forward(self, hidden_states):
return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
return expert_output, shared_expert_output, mlp_bias
def _submodule_post_combine_forward(
self, expert_output, shared_expert_output, mlp_bias, residual
):
"""
Re-combines the expert outputs (and optional shared_expert_output) into the same order
as the original input tokens, applying any required bias.
"""
output = self.mlp.token_dispatcher.combine_postprocess(expert_output)
def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
assert mlp_bias is None
node.common_state.pre_mlp_layernorm_output = None
if shared_expert_output is None:
return expert_output
return expert_output, shared_expert_output
def _submodule_combine_forward(self, expert_output, shared_expert_output=None, state=None):
residual = state.residual
token_dispatcher = self.mlp.token_dispatcher
permutated_local_input_tokens = token_dispatcher.combine_all_to_all(expert_output)
output = token_dispatcher.combine_postprocess(permutated_local_input_tokens)
if shared_expert_output is not None:
output += shared_expert_output
mlp_output_with_bias = (output, mlp_bias)
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
mlp_output_with_bias[0]
)
output = output + shared_expert_output
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="mlp_bda")
mlp_output_with_bias = (output, None)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="mlp_bda")
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output
def _submodule_attention_dw(self):
self.self_attention.backward_dw()
def _submodule_combine_postprocess(self, node, output):
cur_stream = torch.cuda.current_stream()
node.common_state.residual.record_stream(cur_stream)
node.common_state.residual = None
return output
def _submodule_attention_router_compound_dw(self):
self._submodule_attention_dw()
def _submodule_attn_router_dw(self):
self.self_attention.backward_dw()
def _submodule_mlp_dw(self):
self.mlp.backward_dw()
def _submodule_attn_postprocess(self, node, pre_mlp_layernorm_output, residual, context):
return pre_mlp_layernorm_output, residual
def _submodule_dense_postprocess(self, node, hidden_states):
return hidden_states
def _submodule_not_implemented(self, *args, **kwargs):
raise NotImplementedError("This callable is not implemented.")
def get_submodule_callables(self, chunk_state):
"""
The forward callables take two kinds of inputs:
1. The ScheduleNode object.
2. The input tensors.
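Example (illustrative):
    outputs = callables.attention.forward(schedule_node, hidden_states)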
"""
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher
self.is_moe = isinstance(self.mlp, MoELayer)
self.is_deepep = False
if self.is_moe:
self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)
def get_func_with_default(func, default_func):
if self.is_moe:
return func
return default_func
def callable_wrapper(forward_func, postprocess_func, node, *args):
state = getattr(node, 'common_state', None)
callable_outputs = forward_func(*args, state=state)
if isinstance(callable_outputs, tuple):
outputs = postprocess_func(node, *callable_outputs)
else:
outputs = postprocess_func(node, callable_outputs)
return outputs
attn_func = get_func_with_default(
self._submodule_attn_router_forward, self._forward_attention
)
def attn_wrapper(hidden_states, state=None):
"""
state (Any, optional): Placeholder for submodule callable wrapper.
"""
return attn_func(
hidden_states=hidden_states,
attention_mask=chunk_state.attention_mask,
context=chunk_state.context,
context_mask=chunk_state.context_mask,
rotary_pos_emb=chunk_state.rotary_pos_emb,
rotary_pos_cos=chunk_state.rotary_pos_cos,
rotary_pos_sin=chunk_state.rotary_pos_sin,
attention_bias=chunk_state.attention_bias,
inference_context=chunk_state.inference_context,
packed_seq_params=chunk_state.packed_seq_params,
sequence_len_offset=chunk_state.sequence_len_offset,
inference_params=chunk_state.inference_params,
)
attn_postprocess_func = get_func_with_default(
self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
)
dispatch_func = get_func_with_default(
self._submodule_dispatch_forward, self._submodule_not_implemented
)
dispatch_postprocess_func = get_func_with_default(
self._submodule_dispatch_postprocess, self._submodule_not_implemented
)
mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
mlp_postprocess_func = get_func_with_default(
self._submodule_mlp_postprocess, self._submodule_dense_postprocess
)
combine_func = get_func_with_default(
self._submodule_combine_forward, self._submodule_not_implemented
)
combine_postprocess_func = get_func_with_default(
self._submodule_combine_postprocess, self._submodule_not_implemented
)
attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)
callables = TransformerLayerSubmoduleCallables(
attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
dispatch=SubmoduleCallables(forward=dispatch_forward),
mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
combine=SubmoduleCallables(forward=combine_forward),
is_moe=self.is_moe,
is_deepep=self.is_deepep,
)
return callables