Commit 32ee381a authored by dongcl

a2a overlap

parent 12b56c98
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import contextlib
import weakref
from typing import Optional
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch import Tensor
from megatron.core.pipeline_parallel.combined_1f1b import (
AbstractSchedulePlan,
FakeScheduleNode,
FreeInputsMemoryStrategy,
NoOpMemoryStrategy,
ScheduleNode,
get_com_stream,
get_comp_stream,
......@@ -19,15 +14,11 @@ from megatron.core.pipeline_parallel.combined_1f1b import (
)
from megatron.core.transformer import transformer_layer
from megatron.core.transformer.module import float16_to_fp32
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
def weak_method(method):
"""Creates a weak reference to a method to prevent circular references.
This function creates a weak reference to a method and returns a wrapper function
that calls the method when invoked. This helps prevent memory leaks from circular
references.
"""
method_ref = weakref.WeakMethod(method)
del method
......@@ -38,78 +29,24 @@ def weak_method(method):
return wrapped_func
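# Illustrative usage sketch (an assumption added for documentation, not part of this commit):
# the wrapper only holds a weak reference, so storing it on a node does not keep the node's
# owner alive through the stored callable.
#
#   class _Owner:
#       def run(self):
#           return "ran"
#
#   owner = _Owner()
#   fn = weak_method(owner.run)
#   fn()          # calls owner.run() while `owner` is alive
#   del owner     # once `owner` is collected, the weak reference goes dead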
class MemoryStrategyRegistry:
"""Registry for memory management strategies based on node names.
This class centralizes the definition of which memory strategy
should be used for each type of node in the computation graph.
"""
@classmethod
def get_strategy_by_name(cls, name, is_moe, is_deepep):
"""Gets the appropriate memory strategy for a node based on its name and MoE status.
Args:
name: The name of the node, which determines which strategy to use.
is_moe: Whether the node is part of a Mixture of Experts model.
is_deepep: Whether the DeepEP token dispatcher is used (its dispatch inputs must not be freed).
Returns:
The memory strategy to use for the node.
"""
strategies = {
"default": NoOpMemoryStrategy(),
"attn": NoOpMemoryStrategy(), # Attention nodes keep their inputs
"dispatch": (
FreeInputsMemoryStrategy() if not is_deepep else NoOpMemoryStrategy()
), # DeepEP dispatch inputs share the same storage as the MoE inputs, so they must be kept
"mlp": FreeInputsMemoryStrategy(), # MLP nodes free inputs after use
"combine": FreeInputsMemoryStrategy(), # Combine nodes free inputs after use
}
if is_moe:
return strategies.get(name, strategies["default"])
# For dense layers [attn, fake, mlp, fake], the inputs of mlp are required for backward
return NoOpMemoryStrategy()
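# Illustrative sketch (assumption, for documentation only): how the registry resolves a
# strategy. For MoE layers the mlp/combine nodes free their inputs after use, while dense
# layers always keep inputs because they are needed for backward.
#
#   s = MemoryStrategyRegistry.get_strategy_by_name("mlp", is_moe=True, is_deepep=False)
#   isinstance(s, FreeInputsMemoryStrategy)   # True
#   s = MemoryStrategyRegistry.get_strategy_by_name("mlp", is_moe=False, is_deepep=False)
#   isinstance(s, NoOpMemoryStrategy)         # True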
class PreProcessNode(ScheduleNode):
"""Node responsible for preprocessing operations in the model.
This node handles embedding and rotary positional embedding computations
before the main transformer layers.
"""
def __init__(self, gpt_model, model_chunk_state, event, stream):
"""Initializes a preprocessing node.
Args:
gpt_model: The GPT model instance.
model_chunk_state: State shared across the model chunk.
event: CUDA event for synchronization.
stream: CUDA stream for execution.
"""
super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process")
super().__init__(weak_method(self.forward_impl), stream, event)
self.gpt_model = gpt_model
self.model_chunk_state = model_chunk_state
def forward_impl(self):
"""Implements the forward pass for preprocessing.
This method handles:
1. Decoder embedding computation
2. Rotary positional embedding computation
3. Sequence length offset computation for flash decoding
Returns:
The processed decoder input tensor.
"""
gpt_model = self.gpt_model
decoder_input = self.model_chunk_state.decoder_input
input_ids = self.model_chunk_state.input_ids
position_ids = self.model_chunk_state.position_ids
inference_params = self.model_chunk_state.inference_params
inference_context = self.model_chunk_state.inference_context
packed_seq_params = self.model_chunk_state.packed_seq_params
inference_context = deprecate_inference_params(inference_context, inference_params)
# Decoder embedding.
if decoder_input is not None:
pass
......@@ -118,42 +55,51 @@ class PreProcessNode(ScheduleNode):
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
# TODO(dongcl)
decoder_input = gpt_model.decoder.input_tensor
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if (
gpt_model.position_embedding_type == 'rope'
and not gpt_model.config.multi_latent_attention
):
if not gpt_model.training and gpt_model.config.flash_decode and inference_params:
if gpt_model.position_embedding_type == 'rope' and not gpt_model.config.multi_latent_attention:
if not gpt_model.training and gpt_model.config.flash_decode and inference_context:
assert (
inference_context.is_static_batching()
), "GPTModel currently only supports static inference batching."
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = gpt_model.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
gpt_model.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
inference_context.max_sequence_length,
gpt_model.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
)
else:
rotary_seq_len = gpt_model.rotary_pos_emb.get_rotary_seq_len(
inference_params,
gpt_model.decoder,
decoder_input,
gpt_model.config,
packed_seq_params,
inference_context, gpt_model.decoder, decoder_input, gpt_model.config, packed_seq_params
)
rotary_pos_emb = gpt_model.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
elif gpt_model.position_embedding_type == 'mrope' and not gpt_model.config.multi_latent_attention:
if gpt_model.training or not gpt_model.config.flash_decode:
rotary_pos_emb = gpt_model.rotary_pos_emb(position_ids, gpt_model.mrope_section)
else:
# Flash decoding uses precomputed cos and sin for RoPE
raise NotImplementedError(
"Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
"MultimodalRotaryEmbedding yet."
)
if (
(gpt_model.config.enable_cuda_graph or gpt_model.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
and inference_context
and inference_context.is_static_batching()
and not gpt_model.training
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
[inference_context.sequence_len_offset] * inference_context.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
......@@ -169,42 +115,48 @@ class PreProcessNode(ScheduleNode):
class PostProcessNode(ScheduleNode):
"""Node responsible for postprocessing operations in the model.
This node handles final layer normalization and output layer computation
after the main transformer layers.
"""
def __init__(self, gpt_model, model_chunk_state, event, stream):
"""Initializes a postprocessing node.
Args:
gpt_model: The GPT model instance.
model_chunk_state: State shared across the model chunk.
event: CUDA event for synchronization.
stream: CUDA stream for execution.
"""
super().__init__(weak_method(self.forward_impl), stream, event, name="post_process")
super().__init__(weak_method(self.forward_impl), stream, event)
self.gpt_model = gpt_model
self.model_chunk_state = model_chunk_state
def forward_impl(self, hidden_states):
"""Implements the forward pass for postprocessing.
This method handles:
1. Final layer normalization
2. Output layer computation
3. Loss computation if labels are provided
Args:
hidden_states: The hidden states from the transformer layers.
Returns:
The logits or loss depending on whether labels are provided.
"""
gpt_model = self.gpt_model
input_ids = self.model_chunk_state.input_ids
position_ids = self.model_chunk_state.position_ids
labels = self.model_chunk_state.labels
loss_mask = self.model_chunk_state.loss_mask
attention_mask = self.model_chunk_state.attention_mask
inference_params = self.model_chunk_state.inference_params
rotary_pos_emb = self.model_chunk_state.rotary_pos_emb
rotary_pos_cos = self.model_chunk_state.rotary_pos_cos
rotary_pos_sin = self.model_chunk_state.rotary_pos_sin
packed_seq_params = self.model_chunk_state.packed_seq_params
sequence_len_offset = self.model_chunk_state.sequence_len_offset
runtime_gather_output = self.model_chunk_state.runtime_gather_output
inference_context = self.model_chunk_state.inference_context
decoder_input = self.model_chunk_state.decoder_input
extra_block_kwargs = self.model_chunk_state.extra_block_kwargs
# Final layer norm.
if self.gpt_model.decoder.final_layernorm is not None:
hidden_states = self.gpt_model.decoder.final_layernorm(hidden_states)
if gpt_model.decoder.final_layernorm is not None:
hidden_states = gpt_model.decoder.final_layernorm(hidden_states)
# TENorm produces a "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
......@@ -212,73 +164,108 @@ class PostProcessNode(ScheduleNode):
inp=hidden_states, requires_grad=True, keep_graph=True
)
gpt_model = self.gpt_model
runtime_gather_output = self.model_chunk_state.runtime_gather_output
labels = self.model_chunk_state.labels
# Process inference output.
if inference_context and not inference_context.is_static_batching():
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
# logits and loss
output_weight = None
if gpt_model.share_embeddings_and_output_weights:
output_weight = gpt_model.shared_embedding_or_output_weight()
if gpt_model.mtp_process:
hidden_states = gpt_model.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=gpt_model.embedding,
output_layer=gpt_model.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=gpt_model.compute_language_model_loss,
**(extra_block_kwargs or {}),
)
if (
gpt_model.mtp_process is not None
and getattr(gpt_model.decoder, "main_final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = gpt_model.decoder.main_final_layernorm(hidden_states)
if not gpt_model.post_process:
return hidden_states
if (
not gpt_model.training
and inference_context is not None
and inference_context.is_static_batching()
and inference_context.materialize_only_last_token_logits
):
hidden_states = hidden_states[-1:, :, :]
logits, _ = gpt_model.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(gpt_model.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(gpt_model.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return float16_to_fp32(logits.transpose(0, 1).contiguous())
loss = float16_to_fp32(gpt_model.compute_language_model_loss(labels, logits))
return logits.transpose(0, 1).contiguous()
loss = gpt_model.compute_language_model_loss(labels, logits)
return loss
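# Shape note (informal): with sequence length s, micro-batch size b and (padded) vocab size v,
# the output layer produces logits of shape [s, b, v]; without labels the transpose above
# returns [b, s, v], otherwise the per-token language-model loss is returned instead.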
class TransformerLayerNode(ScheduleNode):
"""Base class for transformer layer computation nodes.
This class provides common functionality for different types of
transformer layer nodes (attention, MLP, etc.)
"""
def __init__(self, stream, event, state, callables, name="default"):
"""Initialize a transformer layer node.
Args:
stream (torch.cuda.Stream): CUDA stream for execution
event (torch.cuda.Event): Synchronization event
common_state (TransformerLayerState): State shared within a transformer layer
callables (Callable): The callables containing the forward and dw functions; for MoE
layers their context is the per_batch_state_context, otherwise a nullcontext.
name (str): Node name, also used to determine memory strategy
"""
# Get memory strategy based on node name
memory_strategy = MemoryStrategyRegistry.get_strategy_by_name(
name, callables.is_moe, callables.is_deepep
)
def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
super().__init__(
weak_method(self.forward_impl),
stream,
event,
weak_method(self.backward_impl),
memory_strategy=memory_strategy,
name=name,
free_inputs=free_inputs,
)
self.common_state = state
self.callables = callables
# layer state
self.common_state = common_state
# model chunk state
self.chunk_state = chunk_state
self.layer = layer
self.detached = tuple()
self.before_detached = tuple()
def detach(self, t):
"""Detaches a tensor and stores it for backward computation."""
detached = make_viewless(t).detach()
detached.requires_grad = t.requires_grad
self.before_detached = self.before_detached + (t,)
self.detached = self.detached + (detached,)
return detached
def forward_impl(self, *args):
"""Implements the forward pass for the transformer layer node."""
return self.callables.forward(self, *args)
def backward_impl(self, outputs, output_grad):
"""Implements the backward pass for the transformer layer node."""
detached_grad = tuple([e.grad for e in self.detached])
grads = output_grad + detached_grad
self.default_backward_func(outputs + self.before_detached, grads)
......@@ -287,84 +274,197 @@ class TransformerLayerNode(ScheduleNode):
# return grads for record stream
return grads
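# Note on the detach bookkeeping above (informal): tensors handed to detach() are stored
# pairwise in before_detached (still attached to this node's graph) and detached (consumed
# by later nodes). backward_impl then runs autograd over outputs + before_detached with
# grads output_grad + detached_grad, so gradients collected on the detached copies flow
# back into the graph that produced them.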
class MoeAttnNode(TransformerLayerNode):
def forward_impl(self, hidden_states):
attention_mask = self.chunk_state.attention_mask
context = self.chunk_state.context
rotary_pos_emb = self.chunk_state.rotary_pos_emb
rotary_pos_cos = self.chunk_state.rotary_pos_cos
rotary_pos_sin = self.chunk_state.rotary_pos_sin
attention_bias = self.chunk_state.attention_bias
inference_context = self.chunk_state.inference_context
packed_seq_params = self.chunk_state.packed_seq_params
sequence_len_offset = self.chunk_state.sequence_len_offset
inference_params = self.chunk_state.inference_params
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
(
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
) = self.layer._submodule_attention_router_compound_forward(
hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
inference_params=inference_params,
)
self.common_state.tokens_per_expert = tokens_per_expert
# detached here
self.common_state.probs = self.detach(probs)
self.common_state.residual = self.detach(hidden_states)
self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
return permutated_local_input_tokens, permuted_probs
def dw(self):
"""Computes the weight gradients for the transformer layer node."""
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
self.callables.dw()
self.layer._submodule_attention_router_compound_dw()
class TransformerLayerState:
"""State shared within a transformer layer.
This class holds state that is shared between different nodes
within a transformer layer.
"""
class MoeDispatchNode(TransformerLayerNode):
pass
def forward_impl(self, permutated_local_input_tokens, permuted_probs):
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
inputs = permutated_local_input_tokens
tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
)
# release tensor not used by backward
# inputs.untyped_storage().resize_(0)
self.common_state.tokens_per_expert = tokens_per_expert
return global_input_tokens, global_probs
class ModelChunkSate:
"""State shared across a model chunk.
This class holds state that is shared between different components
of a model chunk, such as input tensors, parameters, and configuration.
"""
class MoeMlPNode(TransformerLayerNode):
def forward_impl(self, global_input_tokens, global_probs):
pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
)
assert mlp_bias is None
pass
# pre_mlp_layernorm_output used
self.common_state.pre_mlp_layernorm_output = None
return expert_output, shared_expert_output
def dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
self.layer._submodule_mlp_dw()
class TransformerLayerSchedulePlan:
"""Schedule plan for a transformer layer.
This class organizes the computation nodes for a transformer layer,
including attention, MLP, dispatch, and combine nodes.
"""
def __init__(self, layer, event, chunk_state, comp_stream, com_stream):
"""Initializes a transformer layer schedule plan.
Args:
layer (TransformerLayer): The transformer layer to schedule.
event (torch.cuda.Event): CUDA event for synchronization.
chunk_state (ModelChunkState): State shared across the model chunk.
comp_stream (torch.cuda.Stream): CUDA stream for computation.
com_stream (torch.cuda.Stream): CUDA stream for communication.
"""
self.common_state = TransformerLayerState()
# get callables for transformer layer
attn_callable, dispatch_callable, mlp_callable, combine_callable = (
layer.get_submodule_callables(chunk_state).as_array()
)
# Create nodes for different operations in the layer
# Each node type has a predefined name that determines its memory strategy
self.attn = TransformerLayerNode(
comp_stream, event, self.common_state, attn_callable, name="attn"
)
self.mlp = TransformerLayerNode(
comp_stream, event, self.common_state, mlp_callable, name="mlp"
)
if attn_callable.is_moe:
self.dispatch = TransformerLayerNode(
com_stream, event, self.common_state, dispatch_callable, name="dispatch"
class MoeCombineNode(TransformerLayerNode):
def forward_impl(self, expert_output, shared_expert_output):
# TODO(lhb): if dw uses the grad of residual and probs, the necessary synchronization should be added
residual = self.common_state.residual
token_dispatcher = self.layer.mlp.token_dispatcher
with token_dispatcher.per_batch_state_context(self.common_state):
permutated_local_input_tokens = token_dispatcher.combine_all_to_all(
expert_output
)
self.combine = TransformerLayerNode(
com_stream, event, self.common_state, combine_callable, name="combine"
output = self.layer._submodule_post_combine_forward(
permutated_local_input_tokens, shared_expert_output, None, residual
)
else:
self.dispatch = FakeScheduleNode()
self.combine = FakeScheduleNode()
cur_stream = torch.cuda.current_stream()
self.common_state.residual.record_stream(cur_stream)
self.common_state.probs.record_stream(cur_stream)
self.common_state.residual = None
self.common_state.probs = None
return output
class ModelChunkSchedulePlan(AbstractSchedulePlan):
"""Schedule plan for a model chunk.
class DenseAttnNode(TransformerLayerNode):
def forward_impl(self, hidden_states):
attention_mask = self.chunk_state.attention_mask
rotary_pos_emb = self.chunk_state.rotary_pos_emb
rotary_pos_cos = self.chunk_state.rotary_pos_cos
rotary_pos_sin = self.chunk_state.rotary_pos_sin
attention_bias = self.chunk_state.attention_bias
inference_context = self.chunk_state.inference_context
packed_seq_params = self.chunk_state.packed_seq_params
sequence_len_offset = self.chunk_state.sequence_len_offset
inference_params = self.chunk_state.inference_params
hidden_states = self.layer._submodule_attention_forward(
hidden_states,
attention_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
attention_bias,
inference_context,
packed_seq_params,
sequence_len_offset,
inference_params=inference_params,
)
return hidden_states
This class organizes the computation nodes for a model chunk,
including preprocessing, transformer layers, and postprocessing.
"""
class FakeScheduleNode:
def forward(self, inputs):
return inputs
def backward(self, outgrads):
return outgrads
class DenseMlpNode(TransformerLayerNode):
def forward_impl(self, hidden_states):
return self.layer._submodule_dense_forward(hidden_states)
def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
common_state = TransformerLayerState()
attn = DenseAttnNode(chunk_state, common_state, layer, comp_stream, event)
attn.name = "attn"
dispatch = FakeScheduleNode()
mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
combine = FakeScheduleNode()
return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
def build_layer_schedule_plan(layer, event, chunk_state, comp_stream, com_stream):
if not isinstance(layer.mlp, MoELayer):
return build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream)
common_state = TransformerLayerState()
attn = MoeAttnNode(chunk_state, common_state, layer, comp_stream, event)
attn.name = "attn"
dispatch = MoeDispatchNode(chunk_state, common_state, layer, com_stream, event, True)
dispatch.name = "dispatch"
mlp = MoeMlPNode(chunk_state, common_state, layer, comp_stream, event, True)
mlp.name = "mlp"
combine = MoeCombineNode(chunk_state, common_state, layer, com_stream, event, True)
combine.name = "combine"
return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
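# Overlap sketch (informal reading of the plan above): for a MoE layer the four nodes are
# split across the two streams,
#
#   attn (comp_stream) -> dispatch (com_stream, all-to-all) -> mlp (comp_stream)
#        -> combine (com_stream, all-to-all)
#
# so schedule_layer_1f1b() can run the a2a of one micro-batch under the attention/MLP
# compute of another micro-batch; dense layers plug FakeScheduleNode placeholders into the
# dispatch/combine slots so the same schedule shape applies.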
class TransformerLayerState(MoEAlltoAllPerBatchState):
pass
class ModelChunkSate:
pass
class TransformerLayerSchedulePlan:
def __init__(self, attn, dispatch, mlp, combine):
self.attn = attn
self.dispatch = dispatch
self.mlp = mlp
self.combine = combine
class ModelChunkSchedulePlan(AbstractSchedulePlan):
def __init__(self):
"""Initializes a model chunk schedule plan."""
super().__init__()
self._pre_process = None
self._post_process = None
......@@ -385,22 +485,7 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
post_forward=None,
post_backward=None,
):
"""Schedules forward and backward passes for model chunks.
Args:
f_schedule_plan (ModelChunkSchedulePlan): Forward schedule plan.
b_schedule_plan (ModelChunkSchedulePlan): Backward schedule plan.
grad (Tensor): Gradient for backward computation.
f_context (VppContextManager or None): The VppContextManager for the forward pass.
b_context (VppContextManager or None): The VppContextManager for the backward pass
pre_forward (Callable): Callback for preprocessing in forward pass.
pre_backward (Callable): Callback for preprocessing in backward pass.
post_forward (Callable): Callback for postprocessing in forward pass.
post_backward (Callable): Callback for postprocessing in backward pass.
Returns:
The output of the forward pass.
"""
return schedule_chunk_1f1b(
f_schedule_plan,
b_schedule_plan,
......@@ -415,55 +500,44 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
@property
def event(self):
"""Gets the CUDA event for synchronization."""
return self._event
def record_current_stream(self):
"""Records the current CUDA stream in the event."""
stream = torch.cuda.current_stream()
self.event.record(stream)
def wait_current_stream(self):
"""Waits for the event to complete on the current CUDA stream."""
stream = torch.cuda.current_stream()
self.event.wait(stream)
@property
def pre_process(self):
"""Gets the preprocessing node."""
return self._pre_process
@pre_process.setter
def pre_process(self, value):
"""Sets the preprocessing node."""
self._pre_process = value
@property
def post_process(self):
"""Gets the postprocessing node."""
return self._post_process
@post_process.setter
def post_process(self, value):
"""Sets the postprocessing node."""
self._post_process = value
def get_layer(self, i):
"""Gets the transformer layer at the specified index."""
assert i < self.num_layers()
return self._transformer_layers[i]
def num_layers(self):
"""Gets the number of transformer layers."""
return len(self._transformer_layers)
def add_layer(self, layer):
"""Adds a transformer layer to the schedule plan."""
self._transformer_layers.append(layer)
@property
def state(self):
"""Gets the model chunk state."""
return self._model_chunk_state
......@@ -478,40 +552,24 @@ def schedule_layer_1f1b(
f_context=None,
b_context=None,
):
"""Schedule one-forward-one-backward operations for a single layer.
This function interleaves forward and backward operations to maximize
parallelism and efficiency.
Args:
f_layer (TransformerLayerSchedulePlan): Forward layer (for current microbatch)
b_layer (TransformerLayerSchedulePlan): Backward layer (for previous microbatch)
f_input (Tensor): Input for forward computation
b_grad (Tensor): Gradient for backward computation
pre_forward (Callable): Callback to get forward input if not provided
pre_backward (Callable): Callback to get backward gradient if not provided
pre_backward_dw (Callable): Callback for weight gradient computation
f_context (VppContextManager or None): The VppContextManager for the forward pass.
b_context (VppContextManager or None): The VppContextManager for the backward pass
Returns:
Functions or values for next iteration's computation
"""
f_context = f_context if f_context is not None else contextlib.nullcontext()
b_context = b_context if b_context is not None else contextlib.nullcontext()
if pre_forward is not None:
assert f_input is None
# combine from last iter
f_input = pre_forward()
del pre_forward
if pre_backward is not None:
# attn backward from last iter
assert b_grad is None
b_grad = pre_backward()
del pre_backward
if b_layer is not None:
with b_context:
b_grad = b_layer.combine.backward(b_grad)
......@@ -520,6 +578,7 @@ def schedule_layer_1f1b(
pre_backward_dw()
del pre_backward_dw
if f_layer is not None:
with f_context:
f_input = f_layer.attn.forward(f_input)
......@@ -534,10 +593,13 @@ def schedule_layer_1f1b(
b_grad = b_layer.dispatch.backward(b_grad)
b_layer.mlp.dw()
if f_layer is not None:
with f_context:
f_input = f_layer.mlp.forward(f_input)
def next_iter_pre_forward():
if f_layer is not None:
with f_context:
......@@ -555,6 +617,7 @@ def schedule_layer_1f1b(
with b_context:
b_layer.attn.dw()
if f_layer and b_layer:
return next_iter_pre_forward, next_iter_pre_backward, next_iter_pre_backward_dw
else:
......@@ -572,32 +635,14 @@ def schedule_chunk_1f1b(
post_forward=None,
post_backward=None,
):
"""Schedules one-forward-one-backward operations for a model chunk.
This function interleaves forward and backward operations across multiple layers
to maximize parallelism and efficiency.
Args:
f_schedule_plan: Forward schedule plan.
b_schedule_plan: Backward schedule plan.
grad: Gradient for backward computation.
f_context: Context for forward computation.
b_context: Context for backward computation.
pre_forward: Callback for preprocessing in forward pass.
pre_backward: Callback for preprocessing in backward pass.
post_forward: Callback for postprocessing in forward pass.
post_backward: Callback for postprocessing in backward pass.
Returns:
The output of the forward pass.
"""
f_context = f_context if f_context is not None else contextlib.nullcontext()
b_context = b_context if b_context is not None else contextlib.nullcontext()
if f_schedule_plan:
# pp output send/receive sync
if pre_forward is not None:
with f_context: # virtual pipeline parallel context
with f_context:
pre_forward()
f_schedule_plan.record_current_stream()
......@@ -617,14 +662,14 @@ def schedule_chunk_1f1b(
if b_schedule_plan is not None:
assert grad is not None
if b_schedule_plan.post_process is not None:
with b_context: # virtual pipeline parallel context
with b_context:
tmp = b_schedule_plan.post_process.backward(grad)
if pre_backward is not None:
# pp grad send receive sync here, safe for now, maybe not safe in the future
with torch.cuda.stream(get_com_stream()):
b_schedule_plan.wait_current_stream()
with b_context: # virtual pipeline parallel context
with b_context:
pre_backward()
b_schedule_plan.record_current_stream()
......@@ -652,9 +697,6 @@ def schedule_chunk_1f1b(
)
torch.cuda.nvtx.range_pop()
# tail forward
f_input = layer_pre_forward()
del layer_pre_forward
# tail backward
grad = layer_pre_backward()
del layer_pre_backward
......@@ -665,12 +707,12 @@ def schedule_chunk_1f1b(
tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
torch.cuda.nvtx.range_pop()
# if b_schedule_plan is not None:
# b_schedule_plan.pre_process.backward(grad)
if b_schedule_plan is not None:
b_schedule_plan.pre_process.backward(grad)
# # tail forward
# f_input = layer_pre_forward()
# del layer_pre_forward
# tail forward
f_input = layer_pre_forward()
del layer_pre_forward
with f_context:
for i in range(overlaped_layers, f_num_layers):
f_layer = f_schedule_plan.get_layer(i)
......@@ -678,8 +720,8 @@ def schedule_chunk_1f1b(
f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
torch.cuda.nvtx.range_pop()
# if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
# f_input = f_schedule_plan.post_process.forward(f_input)
if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
f_input = f_schedule_plan.post_process.forward(f_input)
# output pp send receive, overlapped with attn backward
if f_schedule_plan is not None and post_forward is not None:
......@@ -687,8 +729,7 @@ def schedule_chunk_1f1b(
f_schedule_plan.wait_current_stream()
post_forward(f_input)
# pp grad send / receive, overlapped with attn dw of cur micro-batch
# and forward attn of next micro-batch
# pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
if b_schedule_plan is not None and post_backward is not None:
with b_context:
b_schedule_plan.wait_current_stream()
......@@ -698,13 +739,6 @@ def schedule_chunk_1f1b(
layer_pre_backward_dw()
del layer_pre_backward_dw
with f_context:
if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
f_input = f_schedule_plan.post_process.forward(f_input)
with b_context:
if b_schedule_plan is not None:
b_schedule_plan.pre_process.backward(grad)
if f_schedule_plan:
f_schedule_plan.wait_current_stream()
if b_schedule_plan:
......@@ -720,32 +754,15 @@ def build_model_chunk_schedule_plan(
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params=None,
packed_seq_params=None,
extra_block_kwargs=None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None
):
"""Builds a schedule plan for a model chunk.
This function creates a schedule plan for a model chunk, including
preprocessing, transformer layers, and postprocessing.
Args:
model: The model to build a schedule plan for.
input_ids: Input token IDs.
position_ids: Position IDs.
attention_mask: Attention mask.
decoder_input: Decoder input tensor.
labels: Labels for loss computation.
inference_params: Parameters for inference.
packed_seq_params: Parameters for packed sequences.
extra_block_kwargs: Additional keyword arguments for blocks.
runtime_gather_output: Whether to gather output at runtime.
Returns:
The model chunk schedule plan.
"""
comp_stream = get_comp_stream()
comp_stream = torch.cuda.current_stream()
com_stream = get_com_stream()
model_chunk_schedule_plan = ModelChunkSchedulePlan()
event = model_chunk_schedule_plan.event
......@@ -756,23 +773,28 @@ def build_model_chunk_schedule_plan(
state.attention_mask = attention_mask
state.decoder_input = decoder_input
state.labels = labels
state.inference_params = inference_params
state.inference_context = inference_context
state.packed_seq_params = packed_seq_params
state.extra_block_kwargs = extra_block_kwargs
state.runtime_gather_output = runtime_gather_output
state.inference_params = inference_params
state.loss_mask = loss_mask
state.context = None
state.context_mask = None
state.attention_bias = None
# build preprocess
model_chunk_schedule_plan.pre_process = PreProcessNode(model, state, event, comp_stream)
model_chunk_schedule_plan.pre_process.name = "pre_process"
# build for layers
for layer_idx in range(model.decoder.num_layers_per_pipeline_rank):
layer = model.decoder._get_layer(layer_idx)
layer_plan = TransformerLayerSchedulePlan(layer, event, state, comp_stream, com_stream)
layer_plan = build_layer_schedule_plan(layer, event, state, comp_stream, com_stream)
model_chunk_schedule_plan.add_layer(layer_plan)
# build post process
if model.post_process:
model_chunk_schedule_plan.post_process = PostProcessNode(model, state, event, comp_stream)
model_chunk_schedule_plan.post_process.name = "post_process"
return model_chunk_schedule_plan
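# Illustrative usage sketch (assumption: `model` is an unwrapped GPTModel and the compute /
# communication streams have already been initialised via set_streams()):
#
#   plan = build_model_chunk_schedule_plan(
#       model, input_ids, position_ids, attention_mask, labels=labels
#   )
#   out = schedule_chunk_forward(plan)            # forward only
#   # or, to overlap with the previous micro-batch:
#   # out = schedule_chunk_1f1b(plan, prev_plan, grad)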
......@@ -9,6 +9,7 @@ from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
......@@ -64,11 +65,14 @@ class GPTModel(MegatronCoreGPTModel):
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
):
"""Builds a computation schedule plan for the model.
......@@ -105,10 +109,12 @@ class GPTModel(MegatronCoreGPTModel):
attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_params=inference_params,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
extra_block_kwargs=extra_block_kwargs,
runtime_gather_output=runtime_gather_output,
inference_params=inference_params,
loss_mask=loss_mask,
)
def forward(
......@@ -118,14 +124,16 @@ class GPTModel(MegatronCoreGPTModel):
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
......@@ -137,6 +145,8 @@ class GPTModel(MegatronCoreGPTModel):
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
inference_context = deprecate_inference_params(inference_context, inference_params)
# Decoder embedding.
if decoder_input is not None:
pass
......@@ -152,39 +162,64 @@ class GPTModel(MegatronCoreGPTModel):
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
if not self.training and self.config.flash_decode and inference_context:
assert (
inference_context.is_static_batching()
), "GPTModel currently only supports static inference batching."
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
inference_context.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
inference_context, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
if self.training or not self.config.flash_decode:
rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
else:
# Flash decoding uses precomputed cos and sin for RoPE
raise NotImplementedError(
"Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
"MultimodalRotaryEmbedding yet."
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
and inference_context
and inference_context.is_static_batching()
and not self.training
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
[inference_context.sequence_len_offset] * inference_context.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if (
inference_context is not None
and not self.training
and not has_config_logger_enabled(self.config)
):
decoder_input = WrappedTensor(decoder_input)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
......@@ -193,6 +228,12 @@ class GPTModel(MegatronCoreGPTModel):
**(extra_block_kwargs or {}),
)
# Process inference output.
if inference_context and not inference_context.is_static_batching():
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
......@@ -230,6 +271,13 @@ class GPTModel(MegatronCoreGPTModel):
if not self.post_process:
return hidden_states
if (
not self.training
and inference_context is not None
and inference_context.is_static_batching()
and inference_context.materialize_only_last_token_logits
):
hidden_states = hidden_states[-1:, :, :]
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
......
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import contextlib
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List, Union
from typing import Any, List, Tuple, Union
import torch
from torch import Tensor
......@@ -11,24 +9,24 @@ from torch.autograd.variable import Variable
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
# Types
Shape = Union[List[int], torch.Size]
def make_viewless(e):
"""Make_viewless util func"""
"""make_viewless util func"""
e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
return e
@contextmanager
def stream_acquire_context(stream, event):
"""Stream acquire context"""
event.wait(stream)
try:
yield
......@@ -36,29 +34,8 @@ def stream_acquire_context(stream, event):
event.record(stream)
class FakeScheduleNode:
"""A placeholder node in the computation graph that simply passes through inputs and outputs.
This class is used as a no-op node in the scheduling system when a real computation node
is not needed but the interface must be maintained. It simply returns its inputs unchanged
in both forward and backward passes.
"""
def forward(self, inputs):
"""Passes through inputs unchanged in the forward pass."""
return inputs
def backward(self, outgrads):
"""Passes through gradients unchanged in the backward pass."""
return outgrads
class ScheduleNode:
"""Base node for fine-grained scheduling.
This class represents a computational node in the pipeline schedule.
It handles the execution of forward and backward operations on a stream.
"""
"""base node for fine-grained schedule"""
def __init__(
self,
......@@ -66,30 +43,19 @@ class ScheduleNode:
stream,
event,
backward_func=None,
memory_strategy=None,
free_inputs=False,
name="schedule_node",
):
"""Initialize a schedule node.
Args:
forward_func (callable): Function to execute during forward pass
stream (torch.cuda.Stream): CUDA stream for computation
event (torch.cuda.Event): Event for synchronization
backward_func (callable, optional): Function for backward pass
memory_strategy (MemoryManagementStrategy, optional): Strategy for memory management
name (str): Name of the node for debugging
"""
self.name = name
self.forward_func = forward_func
self.backward_func = backward_func if backward_func else self.default_backward_func
self.backward_func = backward_func
self.stream = stream
self.event = event
self.memory_strategy = memory_strategy or NoOpMemoryStrategy()
self.free_inputs = free_inputs
self.inputs = None
self.outputs = None
def default_backward_func(self, outputs, output_grad):
"""Default backward function"""
Variable._execution_engine.run_backward(
tensors=outputs,
grad_tensors=output_grad,
......@@ -102,7 +68,8 @@ class ScheduleNode:
return output_grad
def forward(self, inputs=()):
"""Schedule node forward"""
"""schedule node forward"""
if not isinstance(inputs, tuple):
inputs = (inputs,)
return self._forward(*inputs)
......@@ -127,17 +94,19 @@ class ScheduleNode:
self.output = data
torch.cuda.nvtx.range_pop()
# Handle inputs using the memory strategy
self.memory_strategy.handle_inputs(inputs, self.stream)
if self.free_inputs:
for input in inputs:
input.record_stream(self.stream)
input.untyped_storage().resize_(0)
return self.output
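# Note on the free_inputs path above (informal): record_stream() tells the caching allocator
# that the input is still being consumed on this node's stream, and untyped_storage().resize_(0)
# then drops the underlying storage, releasing the activation memory as soon as the forward
# pass no longer needs it.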
def get_output(self):
"""Get the forward output"""
"""get the forward output"""
return self.output
def backward(self, output_grad):
"""Schedule node backward"""
"""schedule node backward"""
if not isinstance(output_grad, tuple):
output_grad = (output_grad,)
return self._backward(*output_grad)
......@@ -149,11 +118,13 @@ class ScheduleNode:
outputs = self.output
if not isinstance(outputs, tuple):
outputs = (outputs,)
assert len(outputs) == len(output_grad), (
f"{len(outputs)} of {type(outputs[0])} is not equal to "
f"{len(output_grad)} of {type(output_grad[0])}"
)
output_grad = self.backward_func(outputs, output_grad)
assert len(outputs) == len(
output_grad
), f"{len(outputs)} of {type(outputs[0])} vs {len(output_grad)} of {type(output_grad[0])}"
if self.backward_func is not None:
output_grad = self.backward_func(outputs, output_grad)
else:
output_grad = self.default_backward_func(outputs, output_grad)
torch.cuda.nvtx.range_pop()
# output_grad maybe from another stream
......@@ -163,7 +134,7 @@ class ScheduleNode:
return self.get_grad()
def get_grad(self):
"""Get the grad of inputs"""
"""get the grad of inputs"""
grad = tuple([e.grad if e is not None else None for e in self.inputs])
# clear state
self.inputs = None
......@@ -175,7 +146,7 @@ class ScheduleNode:
class AbstractSchedulePlan(ABC):
"""To use combined 1f1b, model must implement build_schedule_plan while take the same
"""to use combined 1f1b, model must implement build_schedule_plan while take the same
signature as model forward but return an instance of AbstractSchedulePlan"""
@classmethod
......@@ -207,29 +178,7 @@ def schedule_chunk_1f1b(
post_forward=None,
post_backward=None,
):
"""Model level 1f1b fine-grained schedule
This function schedules the forward and backward passes for a chunk of the model.
It takes in the forward schedule plan, backward schedule plan, gradient, and optional
context managers for the forward and backward passes.
Args:
f_schedule_plan (subclass of AbstractSchedulePlan): The forward schedule plan
b_schedule_plan (subclass of AbstractSchedulePlan): The backward schedule plan
grad (Tensor or None): The gradient of the loss function
f_context (VppContextManager or None): The VppContextManager for the forward pass
b_context (VppContextManager or None): The VppContextManager for the backward pass
pre_forward (callable or None): The function to call before the forward pass
pre_backward (callable or None): The function to call before the backward pass
post_forward (callable or None): The function to call after the forward pass
post_backward (callable or None): The function to call after the backward pass
Returns:
The output of the forward pass
"""
# Calls fine_grained_schedule.py::ModelChunkSchedulePlan.forward_backward(),
# which calls fine_grained_schedule.py::schedule_chunk_1f1b()
"""model level 1f1b fine-grained schedule"""
return type(f_schedule_plan or b_schedule_plan).forward_backward(
f_schedule_plan,
b_schedule_plan,
......@@ -243,19 +192,30 @@ def schedule_chunk_1f1b(
)
def schedule_chunk_forward(schedule_plan):
"""model level fine-grained forward schedule"""
f_input = schedule_chunk_1f1b(schedule_plan, None, None)
return f_input
def schedule_chunk_backward(schedule_plan, grad):
"""model level fine-grained backward schedule"""
tmp = schedule_chunk_1f1b(None, schedule_plan, grad)
_COMP_STREAM = None
_COM_STREAM = None
def set_streams(comp_stream=None, com_stream=None):
"""Set the streams for communication and computation"""
"""set the streams for communication and computation"""
global _COMP_STREAM
global _COM_STREAM
if _COMP_STREAM is not None:
return
if comp_stream is None:
comp_stream = torch.cuda.current_stream()
comp_stream = torch.cuda.Stream(device="cuda")
if com_stream is None:
com_stream = torch.cuda.Stream(device="cuda")
......@@ -266,19 +226,19 @@ def set_streams(comp_stream=None, com_stream=None):
def get_comp_stream():
"""Get the stream for computation"""
"""get the stream for computation"""
global _COMP_STREAM
return _COMP_STREAM
def get_com_stream():
"""Get the stream for communication"""
"""get the stream for communication"""
global _COM_STREAM
return _COM_STREAM
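# Minimal usage sketch (informal): the schedule expects the two streams to be initialised
# once before any schedule plan is built.
#
#   set_streams()                                   # default compute / communication streams
#   comp, comm = get_comp_stream(), get_com_stream()
#   with torch.cuda.stream(comm):
#       pass                                        # a2a traffic issued here overlaps work on `comp`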
class VppContextManager:
"""A reusable context manager for switch vpp stage"""
"""a reusable context manager for switch vpp stage"""
def __init__(self, vpp_rank):
self.vpp_rank = vpp_rank
......@@ -316,58 +276,75 @@ def forward_backward_step(
current_microbatch=None,
encoder_decoder_xattn=False,
):
"""Merged forward and backward step for combined_1f1b.
"""Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.
Args:
Accepts the arguments of both forward_step() and backward_step().
forward_step_func (callable): wrapped by wrap_forward_func(), so it returns a forward
schedule plan that is passed to schedule_chunk_1f1b().
f_context (VppContextManager or nullcontext): The context manager for setting vpp ranks.
b_context (VppContextManager or nullcontext): The context manager for setting vpp ranks.
Only exists in 1f1b steady state with p2p overlap.
pre_forward (callable): The function to call before the forward_step.
pre_backward (callable): The function to call before the backward_step.
post_forward (callable): The function to call after the forward_step.
post_backward (callable): The function to call after the backward_step.
forward_step_func (callable):
The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
for this object is that it needs to be acceptible as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally:
a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
of tensors, etc in the case of inference). To trigger case 3 you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator):
The data iterator.
model (nn.Module):
The model to perform the forward step on.
num_microbatches (int):
The number of microbatches.
input_tensor (Tensor or list[Tensor]):
The input tensor(s) for the forward step.
forward_data_store (list):
The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object):
The configuration object.
collect_non_loss_data (bool, optional):
Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional):
The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional):
Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional):
The current microbatch. Defaults to None.
Returns:
forward_output_tensor (Tensor or list[Tensor]): The output object(s) from the forward step.
forward_num_tokens (Tensor): The number of tokens.
backward_input_tensor_grad (Tensor): The grad of the input tensor.
Description:
This method merges the forward_step() and backward_step() methods in the schedules.py file.
Assuming that:
def forward_step():
# forward_preprocess()
# forward_compute()
# forward_postprocess()
def backward_step():
# backward_preprocess()
# backward_compute()
# backward_postprocess()
Then the forward_backward_step() method will be:
def forward_backward_step():
# forward_preprocess() // the same as the forward_step()
# GENERATE f_schedule_plan // schedule happens in schedule_chunk_1f1b()
# backward_preprocess() // the same as the backward_step()
# COMBINED_FORWARD_BACKWARD_COMPUTE() // by calling schedule_chunk_1f1b()
# forward_postprocess() // the same as the forward_step()
# backward_postprocess() // the same as the backward_step()
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
assert (
checkpoint_activations_microbatch is None
), "checkpoint_activations_microbatch is not supported for combined_1f1b"
if config.combined_1f1b_recipe != "ep_a2a":
raise NotImplementedError(
f"combined_1f1b_recipe {config.combined_1f1b_recipe} not supported yet"
)
from .schedules import set_current_microbatch
if f_model is not None and config.timers is not None:
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if config.enable_autocast:
......@@ -377,7 +354,6 @@ def forward_backward_step(
# forward preprocess
unwrap_output_tensor = False
f_schedule_plan = None
if f_model is not None:
with f_context:
if is_first_microbatch and hasattr(f_model, 'set_is_first_microbatch'):
......@@ -391,10 +367,15 @@ def forward_backward_step(
set_input_tensor = get_attr_wrapped_model(f_model, "set_input_tensor")
set_input_tensor(input_tensor)
with context_manager: # autocast context
f_schedule_plan, loss_func = forward_step_func(data_iterator, f_model)
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, f_model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, f_model, checkpoint_activations_microbatch
)
assert isinstance(
f_schedule_plan, AbstractSchedulePlan
output_tensor, AbstractSchedulePlan
), "first output of forward_step_func must be one instance of AbstractSchedulePlan"
# backward preprocess
......@@ -425,8 +406,9 @@ def forward_backward_step(
torch.autograd.backward(b_output_tensor[0], grad_tensors=b_output_tensor_grad[0])
b_output_tensor_grad[0] = loss_node.get_grad()
f_schedule_plan = output_tensor if f_model else None
grad = b_output_tensor_grad[0] if b_model else None
with context_manager: # autocast context
with context_manager:
# schedule forward and backward
output_tensor = schedule_chunk_1f1b(
f_schedule_plan,
......@@ -442,7 +424,7 @@ def forward_backward_step(
# forward post process
num_tokens = None
if f_model is not None:
if f_model:
with f_context:
num_tokens = torch.tensor(0, dtype=torch.int)
if parallel_state.is_pipeline_last_stage():
......@@ -511,33 +493,18 @@ def forward_backward_step(
return output_tensor, num_tokens, input_tensor_grad
def get_default_cls_for_unwrap():
"""Returns the default classes to unwrap from a model.
This function provides a tuple of classes that should be unwrapped from a model
to access the underlying GPTModel instance. It includes DistributedDataParallel
and Float16Module by default, and also attempts to include LegacyFloat16Module
if available for backward compatibility.
Returns:
tuple: A tuple of classes to unwrap from a model.
"""
cls = (DistributedDataParallel, Float16Module)
try:
# legacy should not be used in core, but for backward compatibility, we support it here
from megatron.legacy.model import Float16Module as LegacyFloat16Module
cls = cls + (LegacyFloat16Module,)
except:
pass
return cls
def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
"""Unwrap_model DistributedDataParallel and Float16Module wrapped model
to return GPTModel instance
"""
"""unwrap_model DistributedDataParallel and Float16Module wrapped model"""
return_list = True
if not isinstance(model, list):
model = [model]
......@@ -546,80 +513,19 @@ def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
assert isinstance(
model_module, GPTModel
), "The final unwrapped model must be a GPTModel instance"
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model
def wrap_forward_func(forward_step_func):
"""Wrap the input to forward_step_func.
The wrapped function will return forward_schedule_plan and the loss_function.
"""
def wrap_forward_func(config, forward_step_func):
"""wrap the input to forward_step_func, to make forward_step_func return schedule plan"""
def wrapped_func(data_iterator, model):
# Model is unwrapped to get GPTModel instance.
# GPTModel.build_schedule_plan(model_forward_inputs) is called in the forward_step.
# The return value becomes (forward_schedule_plan, loss_function),
# in place of the usual (forward_output_tensor, loss_function).
return forward_step_func(data_iterator, unwrap_model(model).build_schedule_plan)
return wrapped_func
class MemoryManagementStrategy:
"""Base class for memory management strategies.
Different memory management strategies can be implemented by subclassing this class.
These strategies control how tensors are handled in memory during the computation.
"""
def handle_inputs(self, inputs, stream):
"""Process input tensors after computation.
Args:
inputs (tuple): Input tensors that have been used
stream (torch.cuda.Stream): Current CUDA stream
"""
pass
def handle_outputs(self, outputs, stream):
"""Process output tensors after computation.
Args:
outputs (tuple): Output tensors produced by the computation
stream (torch.cuda.Stream): Current CUDA stream
"""
pass
class NoOpMemoryStrategy(MemoryManagementStrategy):
"""Strategy that performs no memory management operations.
This is the default strategy - it doesn't free any memory.
"""
pass
class FreeInputsMemoryStrategy(MemoryManagementStrategy):
"""Strategy that immediately frees input tensors after they are used.
This strategy is useful for nodes where inputs are no longer needed
after computation, helping to reduce memory usage.
"""
def handle_inputs(self, inputs, stream):
"""Free input tensors by resizing their storage to zero.
Args:
inputs (tuple): Input tensors to be freed
stream (torch.cuda.Stream): Current CUDA stream
"""
for input in inputs:
if input is not None:
input.record_stream(stream)
input.untyped_storage().resize_(0)
if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a":
return wrapped_func
else:
return forward_step_func
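# Editor's illustrative sketch (assumption, not from this commit): the wrapper is only
# applied when the a2a-overlap recipe is active, so existing forward_step_func
# implementations keep working unchanged. `my_forward_step` and `config` are hypothetical.
#
#     step_func = wrap_forward_func(config, my_forward_step)
#     # combined_1f1b with recipe "ep_a2a" -> step_func returns (schedule_plan, loss_func)
#     # otherwise                          -> step_func is my_forward_step, unchanged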
from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
def token_permutation(
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher
# Decouple the per-batch state from MoEAlltoAllTokenDispatcher.
class MoEAlltoAllPerBatchState:
def __init__(self, build_event=False):
self.num_global_tokens_per_local_expert = None
self.output_splits_tp = None
self.output_splits = None
self.input_splits = None
self.num_out_tokens = None
self.capacity = None
self.preprocess_event = None
self.hidden_shape = None
self.probs = None
self.routing_map = None
self.reversed_local_input_permutation_mapping = None
self.cuda_sync_point = None
self.hidden_shape_before_permute = None
class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
state.num_global_tokens_per_local_expert = getattr(
self, "num_global_tokens_per_local_expert", None
)
state.output_splits_tp = getattr(self, "output_splits_tp", None)
state.output_splits = getattr(self, "output_splits", None)
state.input_splits = getattr(self, "input_splits", None)
state.num_out_tokens = getattr(self, "num_out_tokens", None)
state.capacity = getattr(self, "capacity", None)
state.preprocess_event = getattr(self, "preprocess_event", None)
state.hidden_shape = getattr(self, "hidden_shape", None)
state.probs = getattr(self, "probs", None)
state.routing_map = getattr(self, "routing_map", None)
state.reversed_local_input_permutation_mapping = getattr(
self, "reversed_local_input_permutation_mapping", None
)
state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
self.output_splits_tp = state.output_splits_tp
self.output_splits = state.output_splits
self.input_splits = state.input_splits
self.num_out_tokens = state.num_out_tokens
self.capacity = state.capacity
self.preprocess_event = state.preprocess_event
self.hidden_shape = state.hidden_shape
self.probs = state.probs
self.routing_map = state.routing_map
self.reversed_local_input_permutation_mapping = (
state.reversed_local_input_permutation_mapping
)
self.hidden_shape_before_permute = state.hidden_shape_before_permute
self.cuda_sync_point = state.cuda_sync_point
@contextmanager
def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
origin_state = MoEAlltoAllPerBatchState()
self.collect_per_batch_state(origin_state)
try:
self.apply_per_batch_state(state)
yield
finally:
self.collect_per_batch_state(state)
self.apply_per_batch_state(origin_state)
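# Editor's illustrative sketch: per_batch_state_context temporarily installs the
# per-microbatch dispatcher state, saves it back, and restores the original fields on
# exit, so one dispatcher instance can be shared across interleaved microbatches.
# `dispatcher` and `batch_state` are hypothetical names.
#
#     batch_state = MoEAlltoAllPerBatchState()
#     with dispatcher.per_batch_state_context(batch_state):
#         tokens_per_expert = dispatcher.meta_prepare(hidden_states, probs, routing_map)
#     # batch_state now carries the splits/probs/routing_map of this microbatch,
#     # while the dispatcher itself is back to its previous state.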
def meta_prepare(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> torch.Tensor:
"""
Preprocess the routing metadata needed for the subsequent AlltoAll communication
and permutation steps, and record hidden_shape, probs, and routing_map on the
dispatcher.
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
torch.Tensor: Number of tokens per expert.
"""
self.hidden_shape = hidden_states.shape
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
return tokens_per_expert
def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
if self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
@@ -49,12 +98,15 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
) = permute(
hidden_states,
routing_map,
probs=probs,
probs=self.probs,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
return tokens_per_expert, permutated_local_input_tokens, permuted_probs
def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
@@ -65,6 +117,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
global_probs = all_to_all(
self.ep_group, permuted_probs, self.output_splits, self.input_splits
)
return tokens_per_expert, global_input_tokens, global_probs
def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
if self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
@@ -118,184 +174,137 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert, global_probs
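# Editor's illustrative sketch (assumption): chained together, the decomposed stages
# above reproduce the monolithic token_permutation, which is what lets the 1F1B schedule
# overlap the expert-parallel AlltoAll with compute. `dispatcher` is a hypothetical
# MoEAlltoAllTokenDispatcher instance.
#
#     tokens_per_expert = dispatcher.meta_prepare(hidden_states, probs, routing_map)
#     tokens_per_expert, permuted_tokens, permuted_probs = dispatcher.dispatch_preprocess(
#         hidden_states, routing_map, tokens_per_expert)
#     tokens_per_expert, global_tokens, global_probs = dispatcher.dispatch_all_to_all(
#         tokens_per_expert, permuted_tokens, permuted_probs)
#     global_tokens, tokens_per_expert, global_probs = dispatcher.dispatch_postprocess(
#         tokens_per_expert, global_tokens, global_probs)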
class _DeepepManager(MegatronCoreDeepepManager):
"""
Patched Megatron-Core _DeepepManager that exposes async dispatch/combine.
"""
def dispatch(
self,
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) -> torch.Tensor:
# DeepEP only supports float32 probs
if self.token_probs.dtype != torch.float32:
if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
self.token_probs = self.token_probs.float() # downcast or upcast
hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
fused_dispatch(
hidden_states,
self.token_indices,
self.token_probs,
self.num_experts,
self.group,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)
)
self.handle = handle
self.tokens_per_expert = num_tokens_per_expert
self.dispatched_indices = dispatched_indices
self.dispatched_probs = dispatched_probs
return hidden_states
def combine(
self,
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) -> torch.Tensor:
hidden_states, _ = fused_combine(
hidden_states,
self.group,
self.handle,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)
# Release the handle after combine operation
self.handle = None
return hidden_states
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
class MoEFlexTokenDispatcher(MoETokenDispatcher):
"""
Flex token dispatcher using DeepEP.
"""
This method performs the following steps:
1. Preprocess the routing map to get metadata for communication and permutation.
2. Permute input tokens for AlltoAll communication.
3. Perform expert parallel AlltoAll communication.
4. Sort tokens by local expert (if multiple local experts exist).
def dispatch_preprocess(
self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
):
"""
Preprocesses the hidden states and routing information before dispatching tokens to experts.
Args:
hidden_states (torch.Tensor): Input hidden states to be processed
routing_map (torch.Tensor): Map indicating which expert each token should be routed to
probs (torch.Tensor): Routing probabilities for each token-expert pair
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
Tuple containing:
- torch.Tensor: Reshaped hidden states
- torch.Tensor: Token probabilities from the communication manager
- None: Placeholder for compatibility
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
- Permuted probs of each token produced by the router.
"""
self.hidden_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
# Initialize metadata
routing_map, probs = self._initialize_metadata(routing_map, probs)
# Perform expert parallel AlltoAll communication
tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
self._comm_manager.setup_metadata(routing_map, probs)
return hidden_states, self._comm_manager.token_probs, None
# Permutation 2: Sort tokens by local expert.
global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
def dispatch_all_to_all(
self,
hidden_states: torch.Tensor,
probs: torch.Tensor = None,
async_finish: bool = True,
allocate_on_comm_stream: bool = True,
):
"""
Performs all-to-all communication to dispatch tokens across expert parallel ranks.
"""
return (
self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
self._comm_manager.dispatched_probs,
)
return global_input_tokens, tokens_per_expert, global_probs
def dispatch_postprocess(self, hidden_states: torch.Tensor):
"""
Post-processes the dispatched hidden states after all-to-all communication.
def combine_preprocess(self, hidden_states):
# Unpermutation 2: Unsort tokens by local expert.
if self.num_local_experts > 1:
if self.drop_and_pad:
hidden_states = (
hidden_states.view(
self.num_local_experts,
self.tp_size * self.ep_size,
self.capacity,
*hidden_states.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
else:
hidden_states, _ = sort_chunks_by_idxs(
hidden_states,
self.num_global_tokens_per_local_expert.T.ravel(),
self.restore_output_by_local_experts,
fused=self.config.moe_permute_fusion,
)
This method retrieves the permuted hidden states by experts, calculates the number of tokens
per expert, and returns the processed data ready for expert processing.
"""
global_input_tokens, permuted_probs = (
self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
)
tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
return global_input_tokens, tokens_per_expert, permuted_probs
if self.tp_size > 1:
if self.output_splits_tp is None:
input_split_sizes = None
else:
input_split_sizes = self.output_splits_tp.tolist()
# The precision of TP reduce_scatter should be the same as the router_dtype
hidden_states = reduce_scatter_to_sequence_parallel_region(
hidden_states.to(self.probs.dtype),
group=self.tp_group,
input_split_sizes=input_split_sizes,
).to(hidden_states.dtype)
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Permutes tokens according to the routing map and dispatches them to experts.
return hidden_states
This method implements the token permutation process in three steps:
1. Preprocess the hidden states and routing information
2. Perform all-to-all communication to dispatch tokens
3. Post-process the dispatched tokens for expert processing
"""
hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
hidden_states, _ = self.dispatch_all_to_all(
hidden_states, async_finish=False, allocate_on_comm_stream=False
)
global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
hidden_states
def combine_all_to_all(self, hidden_states):
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens = all_to_all(
self.ep_group, hidden_states, self.input_splits, self.output_splits
)
return permutated_local_input_tokens
return global_input_tokens, tokens_per_expert, permuted_probs
def combine_preprocess(self, hidden_states: torch.Tensor):
"""
Pre-processes the hidden states before combining them after expert processing.
This method restores the hidden states to their original ordering before expert processing
by using the communication manager's restoration function.
"""
hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
return hidden_states
def combine_postprocess(self, permutated_local_input_tokens):
if self.shared_experts is not None:
self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
self.shared_experts.post_forward_comm()
def combine_all_to_all(
self,
hidden_states: torch.Tensor,
async_finish: bool = True,
allocate_on_comm_stream: bool = True,
):
"""
Performs all-to-all communication to combine tokens after expert processing.
"""
return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)
# Unpermutation 1: AlltoAll output to output
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
def combine_postprocess(self, hidden_states: torch.Tensor):
"""
Post-processes the combined hidden states after all-to-all communication.
# Reshape the output tensor
output = output.view(self.hidden_shape)
This method reshapes the combined hidden states to match the original input shape.
"""
return hidden_states.view(self.hidden_shape)
# Add shared experts output
if self.shared_experts is not None:
shared_expert_output = self.shared_experts.get_output()
output += shared_expert_output
return output
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverses the token permutation process to restore the original token order.
Reverse the token permutation to restore the original order.
This method performs the following steps:
1. Unsort tokens by local expert (if multiple local experts exist).
2. Perform expert parallel AlltoAll communication to restore the original order.
3. Unpermute tokens to restore the original order.
This method implements the token unpermutation process in three steps:
1. Pre-process the hidden states to restore their original ordering
2. Perform all-to-all communication to combine tokens
3. Post-process the combined tokens to match the original input shape
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
hidden_states = self.combine_preprocess(hidden_states)
hidden_states = self.combine_all_to_all(hidden_states, False, False)
hidden_states = self.combine_postprocess(hidden_states)
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
hidden_states = self.combine_preprocess(hidden_states)
permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
output = self.combine_postprocess(permutated_local_input_tokens)
return hidden_states, None
return output, None
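# Editor's illustrative sketch (assumption): token_unpermutation above is the combine-side
# mirror of the dispatch pipeline for the AlltoAll dispatcher; the schedule can also call
# the three stages separately so the combine AlltoAll overlaps with compute.
#
#     expert_output = dispatcher.combine_preprocess(expert_output)   # unsort by local expert
#     local_tokens = dispatcher.combine_all_to_all(expert_output)    # EP AlltoAll back
#     output = dispatcher.combine_postprocess(local_tokens)          # unpermute + shared experts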
from megatron.core import parallel_state, tensor_parallel
from megatron.core.utils import (
deprecate_inference_params,
make_viewless_tensor,
)
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
class TransformerLayer(MegatronCoreTransformerLayer):
def _submodule_attn_router_forward(
self,
hidden_states,
attention_mask=None,
inference_params=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
sequence_len_offset=None,
state=None,
def _callable_wrapper(
self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs
):
"""
Performs a combined forward pass that includes self-attention and MLP routing logic.
Wraps a function call so that it waits for a given CUDA event before
proceeding and then runs the function on a specified CUDA stream.
"""
hidden_states, _ = self._forward_attention(
hidden_states=hidden_states,
torch.cuda.nvtx.range_push(func.__name__)
event.wait(stream)
with torch.cuda.stream(stream):
outputs = func(*args, **kwargs)
event.record(stream)
if skip_detach:
torch.cuda.nvtx.range_pop()
return outputs
detached_output_tensors = []
if not is_forward:
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
for tensor in outputs:
if tensor is None:
detached_output_tensors.append(None)
elif tensor.dtype.is_floating_point:
detached_output_tensors.append(tensor.detach().requires_grad_(True))
else:
detached_output_tensors.append(tensor.detach())
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
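# Editor's illustrative sketch (assumption): how a submodule forward might be bound with
# _callable_wrapper. Only is_forward and the function are bound up front; the CUDA stream,
# event, and input tensors are presumably supplied by the schedule node at call time.
# Forward calls return the raw outputs plus detached, grad-enabled copies; backward/dw
# calls skip the detach step. `layer`, `comm_stream`, and `comm_event` are hypothetical names.
#
#     dispatch_fwd = partial(layer._callable_wrapper, True, layer._submodule_dispatch_forward)
#     outputs, detached = dispatch_fwd(comm_stream, comm_event,
#                                      tokens_per_expert, permuted_tokens, permuted_probs)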
def _submodule_attention_forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
# todo
inference_context = deprecate_inference_params(inference_context, inference_params)
# Residual connection.
residual = hidden_states
# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
return hidden_states
def _submodule_attention_router_compound_forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
"""
Performs a combined forward pass that includes self-attention and MLP routing logic.
"""
hidden_states = self._submodule_attention_forward(
hidden_states,
attention_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
attention_bias,
inference_context,
packed_seq_params,
sequence_len_offset,
inference_params=inference_params,
)
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
local_tokens, probs, tokens_per_expert = self.mlp.token_dispatcher.dispatch_preprocess(
pre_mlp_layernorm_output, routing_map, probs
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
pre_mlp_layernorm_output, probs, routing_map
)
tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
pre_mlp_layernorm_output, routing_map, tokens_per_expert
)
return (local_tokens, probs, hidden_states, pre_mlp_layernorm_output, tokens_per_expert)
outputs = [
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
permuted_probs,
probs,
]
return tuple(outputs)
def _submodule_dispatch_forward(self, local_tokens, probs, state=None):
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
"""
Dispatches tokens to the appropriate experts based on the router output.
"""
token_dispatcher = self.mlp.token_dispatcher
if self.is_deepep:
token_dispatcher._comm_manager.token_probs = probs
tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
tokens_per_expert, permutated_local_input_tokens, permuted_probs
)
return [tokens_per_expert, global_input_tokens, global_probs]
return token_dispatcher.dispatch_all_to_all(local_tokens, probs)
def _submodule_dense_forward(self, hidden_states):
residual = hidden_states
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output
def _submodule_moe_forward(self, dispatched_tokens, probs=None, state=None):
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, hidden_states):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
shared_expert_output = None
token_dispatcher = self.mlp.token_dispatcher
if self.is_deepep:
token_dispatcher._comm_manager.dispatched_probs = state.dispatched_probs
dispatched_tokens, tokens_per_expert, permuted_probs = (
token_dispatcher.dispatch_postprocess(dispatched_tokens)
)
else:
dispatched_tokens, permuted_probs = token_dispatcher.dispatch_postprocess(
dispatched_tokens, probs
)
tokens_per_expert = state.tokens_per_expert
expert_output, mlp_bias = self.mlp.experts(
dispatched_tokens, tokens_per_expert, permuted_probs
(dispatched_input, tokens_per_expert, permuted_probs) = (
self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_prob)
)
assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
shared_expert_output = self.mlp.shared_experts(hidden_states)
return expert_output, shared_expert_output, mlp_bias
def _submodule_combine_forward(self, output, shared_expert_output=None, state=None):
residual = state.residual
token_dispatcher = self.mlp.token_dispatcher
output = token_dispatcher.combine_all_to_all(output)
output = token_dispatcher.combine_postprocess(output)
def _submodule_combine_forward(self, hidden_states):
return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
def _submodule_post_combine_forward(
self, expert_output, shared_expert_output, mlp_bias, residual
):
"""
Re-combines the expert outputs (and optional shared_expert_output) into the same order
as the original input tokens, applying any required bias.
"""
output = self.mlp.token_dispatcher.combine_postprocess(expert_output)
if shared_expert_output is not None:
output = output + shared_expert_output
mlp_output_with_bias = (output, None)
output += shared_expert_output
mlp_output_with_bias = (output, mlp_bias)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
@@ -92,133 +213,141 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return output
def _submodule_attn_router_dw(self):
def _submodule_attention_backward(
self, hidden_states, pre_mlp_layernorm_output, detached_inputs
):
pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
hidden_states.backward(detached_inputs[0].grad)
def _submodule_attention_router_compound_backward(
self,
hidden_states,
pre_mlp_layernorm_output,
tokens_per_expert,
permutated_local_input_tokens,
probs,
detached_inputs,
):
permutated_local_input_tokens.backward(detached_inputs[3].grad)
probs.backward(detached_inputs[4].grad)
# tokens_per_expert.backward(detached_inputs[2].grad)
pre_mlp_layernorm_output.backward(detached_inputs[1].grad)
hidden_states.backward(detached_inputs[0].grad)
def _submodule_dispatch_backward(self, global_input_tokens, detached_inputs):
global_input_tokens.backward(detached_inputs[0].grad)
def _submodule_dense_backward(self, output, detached_inputs):
output.backward(detached_inputs[0].grad)
def _submodule_moe_backward(
self, expert_output, shared_expert_output, mlp_bias, detached_inputs
):
expert_output.backward(detached_inputs[0].grad)
shared_expert_output.backward(detached_inputs[1].grad)
if mlp_bias is not None:
mlp_bias.backward(detached_inputs[2].grad)
def _submodule_combine_backward(self, hidden_states, detached_inputs):
hidden_states.backward(detached_inputs[0].grad)
def _submodule_post_combine_backward(self, output, output_grad):
output.backward(output_grad)
def _submodule_attention_router_compound_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_attention_router_compound_dw(self):
self.self_attention.backward_dw()
# raise NotImplementedError("Not implemented")
def _submodule_dispatch_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_mlp_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_mlp_dw(self):
self.mlp.backward_dw()
# raise NotImplementedError("Not implemented")
def _submodule_attn_router_postprocess(
self, node, local_tokens, probs, residual, pre_mlp_layernorm_output, tokens_per_expert
):
node.common_state.residual = node.detach(residual)
if self.mlp.use_shared_expert:
node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
if not self.is_deepep:
node.common_state.tokens_per_expert = tokens_per_expert
return local_tokens, probs
def _submodule_dispatch_postprocess(self, node, dispatched_tokens, probs):
if self.is_deepep:
node.common_state.dispatched_probs = node.detach(probs)
return dispatched_tokens
else:
return dispatched_tokens, probs
def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
assert mlp_bias is None
node.common_state.pre_mlp_layernorm_output = None
if shared_expert_output is None:
return expert_output
return expert_output, shared_expert_output
def _submodule_combine_postprocess(self, node, output):
cur_stream = torch.cuda.current_stream()
node.common_state.residual.record_stream(cur_stream)
node.common_state.residual = None
return output
def _submodule_combine_dgrad(self):
raise NotImplementedError("Not implemented")
def _submodule_attn_postprocess(self, node, hidden_states, context):
return hidden_states
def _submodule_identity_forward(self, *args):
return args
def _submodule_dense_postprocess(self, node, hidden_states):
return hidden_states
def _submodule_not_implemented(self, *args):
raise NotImplementedError("This callable is not implemented.")
def _submodule_identity_backward(self, *args):
pass
def get_submodule_callables(self, chunk_state):
def get_submodule_callables(self):
"""
Builds the per-submodule callables for this transformer layer and returns them as a
TransformerLayerSubmoduleCallables instance. Each wrapped forward/backward callable
expects the CUDA stream and event followed by its input tensors.
"""
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher
self.is_moe = isinstance(self.mlp, MoELayer)
self.is_deepep = False
if self.is_moe:
self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)
def get_func_with_default(func, default_func):
if self.is_moe:
if isinstance(self.mlp, MoELayer):
return func
return default_func
def callable_wrapper(forward_func, postprocess_func, node, *args):
state = getattr(node, 'common_state', None)
callable_outputs = forward_func(*args, state=state)
if isinstance(callable_outputs, tuple):
outputs = postprocess_func(node, *callable_outputs)
else:
outputs = postprocess_func(node, callable_outputs)
return outputs
attn_func = get_func_with_default(
self._submodule_attn_router_forward, self._forward_attention
)
def attn_wrapper(hidden_states, state=None):
return attn_func(
hidden_states=hidden_states,
attention_mask=chunk_state.attention_mask,
attention_bias=chunk_state.attention_bias,
inference_params=chunk_state.inference_params,
packed_seq_params=chunk_state.packed_seq_params,
sequence_len_offset=chunk_state.sequence_len_offset,
rotary_pos_emb=chunk_state.rotary_pos_emb,
rotary_pos_cos=chunk_state.rotary_pos_cos,
rotary_pos_sin=chunk_state.rotary_pos_sin,
state=state,
)
attn_postprocess_func = get_func_with_default(
self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
attention_func = get_func_with_default(
self._submodule_attention_router_compound_forward, self._submodule_attention_forward
)
attention_backward_func = get_func_with_default(
self._submodule_attention_router_compound_backward, self._submodule_attention_backward
)
dispatch_func = get_func_with_default(
self._submodule_dispatch_forward, self._submodule_not_implemented
self._submodule_dispatch_forward, self._submodule_identity_forward
)
dispatch_postprocess_func = get_func_with_default(
self._submodule_dispatch_postprocess, self._submodule_not_implemented
dispatch_backward_func = get_func_with_default(
self._submodule_dispatch_backward, self._submodule_identity_backward
)
mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
mlp_postprocess_func = get_func_with_default(
self._submodule_mlp_postprocess, self._submodule_dense_postprocess
mlp_func = get_func_with_default(self._submodule_moe_forward, self._submodule_dense_forward)
mlp_backward_func = get_func_with_default(
self._submodule_moe_backward, self._submodule_dense_backward
)
combine_func = get_func_with_default(
self._submodule_combine_forward, self._submodule_not_implemented
self._submodule_combine_forward, self._submodule_identity_forward
)
combine_postprocess_func = get_func_with_default(
self._submodule_combine_postprocess, self._submodule_not_implemented
combine_backward_func = get_func_with_default(
self._submodule_combine_backward, self._submodule_identity_backward
)
post_combine_func = get_func_with_default(
self._submodule_post_combine_forward, self._submodule_identity_forward
)
post_combine_backward_func = get_func_with_default(
self._submodule_post_combine_backward, self._submodule_identity_backward
)
attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)
callables = TransformerLayerSubmoduleCallables(
attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
dispatch=SubmoduleCallables(forward=dispatch_forward),
mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
combine=SubmoduleCallables(forward=combine_forward),
is_moe=self.is_moe,
is_deepep=self.is_deepep,
attention=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, attention_func, skip_detach=True),
backward=partial(self._callable_wrapper, False, attention_backward_func),
# dgrad=partial(self._callable_wrapper, False,self._submodule_attention_router_compound_dgrad),
dw=partial(
self._callable_wrapper, False, self._submodule_attention_router_compound_dw
),
),
dispatch=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, dispatch_func),
backward=partial(self._callable_wrapper, False, dispatch_backward_func),
# dgrad=partial(self._callable_wrapper, False, self._submodule_dispatch_dgrad),
),
mlp=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, mlp_func),
backward=partial(self._callable_wrapper, False, mlp_backward_func),
# dgrad=partial(self._callable_wrapper, False, self._submodule_mlp_dgrad),
dw=partial(self._callable_wrapper, False, self._submodule_mlp_dw),
),
combine=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, combine_func),
backward=partial(self._callable_wrapper, False, combine_backward_func),
# dgrad=partial(self._callable_wrapper, False, self._submodule_combine_dgrad),
),
post_combine=SubmoduleCallables(
forward=partial(self._callable_wrapper, True, post_combine_func),
backward=partial(self._callable_wrapper, False, post_combine_backward_func),
),
)
return callables
\ No newline at end of file
from dataclasses import dataclass
from typing import Callable, Optional
@dataclass
class SubmoduleCallables:
"""
Holds references to the forward, backward, dgrad, and dw (weight-grad)
callables for a particular submodule.
"""
forward: Optional[Callable] = None
backward: Optional[Callable] = None
dgrad: Optional[Callable] = None
dw: Optional[Callable] = None
@dataclass
class TransformerLayerSubmoduleCallables:
"""
Collects the SubmoduleCallables for each of the submodules:
'attention', 'dispatch', 'mlp', 'combine', and 'post_combine'.
"""
attention: SubmoduleCallables
dispatch: SubmoduleCallables
mlp: SubmoduleCallables
combine: SubmoduleCallables
post_combine: SubmoduleCallables
def as_array(self):
return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
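# Editor's illustrative sketch (assumption): a scheduler can walk the submodules in
# execution order via as_array(), binding each wrapped callable to its stream and event.
# `layer`, `stream`, `event`, and `inputs` are hypothetical names.
#
#     callables = layer.get_submodule_callables()
#     for sub in callables.as_array():   # attention, dispatch, mlp, combine, post_combine
#         if sub.forward is not None:
#             result = sub.forward(stream, event, *inputs)  # outputs (+ detached copies,
#                                                           #  unless skip_detach was set)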