Commit 6a579b17 authored by dongcl

rewrite combined_1f1b

parent e103a256
@@ -12,7 +12,7 @@ from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import transformer_layer
 from megatron.core.transformer.moe.moe_layer import MoELayer
-from megatron.core.utils import deprecate_inference_params
+from megatron.core.utils import WrappedTensor, deprecate_inference_params
 from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
 from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
@@ -25,6 +25,12 @@ from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
 def weak_method(method):
+    """Creates a weak reference to a method to prevent circular references.
+
+    This function creates a weak reference to a method and returns a wrapper function
+    that calls the method when invoked. This helps prevent memory leaks from circular
+    references.
+    """
     method_ref = weakref.WeakMethod(method)
     del method
@@ -35,6 +41,40 @@ def weak_method(method):
     return wrapped_func
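For readers unfamiliar with the pattern, here is a self-contained sketch of the weak-method idea described in the docstring above (a stand-in, not the module's actual helper):

```python
import weakref


def weak_method(method):
    """Wrap a bound method behind a weak reference so the wrapper does not
    keep the owning object alive (prevents circular references)."""
    method_ref = weakref.WeakMethod(method)
    del method

    def wrapped_func(*args, **kwargs):
        target = method_ref()
        assert target is not None, "referenced object has been garbage collected"
        return target(*args, **kwargs)

    return wrapped_func


class Node:
    def forward(self, x):
        return x + 1


node = Node()
fwd = weak_method(node.forward)
print(fwd(1))  # 2; the wrapper holds no strong reference to `node`
```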
+class MemoryStrategyRegistry:
+    """Registry for memory management strategies based on node names.
+
+    This class centralizes the definition of which memory strategy
+    should be used for each type of node in the computation graph.
+    """
+
+    @classmethod
+    def get_strategy_by_name(cls, name, is_moe, is_deepep):
+        """Gets the appropriate memory strategy for a node based on its name and MoE status.
+
+        Args:
+            name: The name of the node, which determines which strategy to use.
+            is_moe: Whether the node is part of a Mixture of Experts model.
+            is_deepep: Whether the MoE layer uses the DeepEP dispatcher.
+
+        Returns:
+            The memory strategy to use for the node.
+        """
+        strategies = {
+            "default": NoOpMemoryStrategy(),
+            "attn": NoOpMemoryStrategy(),  # Attention nodes keep their inputs
+            "dispatch": (
+                FreeInputsMemoryStrategy() if not is_deepep else NoOpMemoryStrategy()
+            ),  # deepep dispatch inputs share same storage with moe inputs
+            "mlp": FreeInputsMemoryStrategy(),  # MLP nodes free inputs after use
+            "combine": FreeInputsMemoryStrategy(),  # Combine nodes free inputs after use
+        }
+        if is_moe:
+            return strategies.get(name, strategies["default"])
+        # For dense layers [attn, fake, mlp, fake], the inputs of mlp are required for backward
+        return NoOpMemoryStrategy()
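A compact sketch of how this lookup behaves; `NoOpMemoryStrategy` and `FreeInputsMemoryStrategy` below are simplified stand-ins for the classes added in combined_1f1b.py later in this commit:

```python
class NoOpMemoryStrategy:
    def handle_inputs(self, inputs, stream):
        pass


class FreeInputsMemoryStrategy:
    def handle_inputs(self, inputs, stream):
        for t in inputs:
            if t is not None:
                t.untyped_storage().resize_(0)


def get_strategy_by_name(name, is_moe, is_deepep):
    strategies = {
        "default": NoOpMemoryStrategy(),
        "attn": NoOpMemoryStrategy(),
        "dispatch": FreeInputsMemoryStrategy() if not is_deepep else NoOpMemoryStrategy(),
        "mlp": FreeInputsMemoryStrategy(),
        "combine": FreeInputsMemoryStrategy(),
    }
    return strategies.get(name, strategies["default"]) if is_moe else NoOpMemoryStrategy()


# MoE dispatch/mlp/combine free their inputs (unless DeepEP shares storage with the
# MoE inputs); dense layers never free, since the MLP inputs are needed for backward.
print(type(get_strategy_by_name("dispatch", is_moe=True, is_deepep=True)).__name__)   # NoOpMemoryStrategy
print(type(get_strategy_by_name("mlp", is_moe=True, is_deepep=False)).__name__)       # FreeInputsMemoryStrategy
print(type(get_strategy_by_name("mlp", is_moe=False, is_deepep=False)).__name__)      # NoOpMemoryStrategy
```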
 class PreProcessNode(ScheduleNode):
     def __init__(self, gpt_model, model_chunk_state, event, stream):
@@ -105,14 +145,25 @@ class PreProcessNode(ScheduleNode):
             and inference_context.is_static_batching()
             and not gpt_model.training
         ):
+            current_batch_size = input_ids.shape[0]
             sequence_len_offset = torch.tensor(
-                [inference_context.sequence_len_offset] * inference_context.current_batch_size,
+                [inference_context.sequence_len_offset] * current_batch_size,
                 dtype=torch.int32,
                 device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
             )
         else:
             sequence_len_offset = None

+        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+        # reference held by this caller function, enabling early garbage collection for
+        # inference. Skip wrapping if decoder_input is logged after decoder completion.
+        if (
+            inference_context is not None
+            and not gpt_model.training
+            and not has_config_logger_enabled(gpt_model.config)
+        ):
+            decoder_input = WrappedTensor(decoder_input)
+
         # saved for later use
         self.model_chunk_state.rotary_pos_emb = rotary_pos_emb
         self.model_chunk_state.rotary_pos_cos = rotary_pos_cos
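`WrappedTensor` is imported from megatron.core.utils; the stand-in below is only meant to illustrate why handing the decoder a wrapper (rather than the tensor itself) lets the callee drop the last strong reference early. It is an assumption-level sketch, not Megatron's implementation:

```python
import torch


class WrappedTensorSketch:
    """Illustrative wrapper: the callee can unwrap and clear the reference,
    so the activation can be garbage collected before the caller returns."""

    def __init__(self, tensor):
        self._tensor = tensor

    def unwrap_and_clear(self):
        tensor, self._tensor = self._tensor, None
        return tensor


decoder_input = torch.randn(128, 1, 4096)
wrapped = WrappedTensorSketch(decoder_input)
del decoder_input                   # caller no longer pins the tensor
inner = wrapped.unwrap_and_clear()  # callee takes ownership; wrapper releases it
print(inner.shape)
```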
@@ -157,12 +208,6 @@ class PostProcessNode(ScheduleNode):
             inp=hidden_states, requires_grad=True, keep_graph=True
         )

-        # Process inference output.
-        if inference_context and not inference_context.is_static_batching():
-            hidden_states = inference_context.last_token_logits(
-                hidden_states.squeeze(1).unsqueeze(0)
-            ).unsqueeze(1)
-
         # logits and loss
         output_weight = None
         if gpt_model.share_embeddings_and_output_weights:
@@ -203,10 +248,18 @@ class PostProcessNode(ScheduleNode):
         if (
             not gpt_model.training
             and inference_context is not None
-            and inference_context.is_static_batching()
             and inference_context.materialize_only_last_token_logits
         ):
-            hidden_states = hidden_states[-1:, :, :]
+            if inference_context.is_static_batching():
+                hidden_states = hidden_states[-1:, :, :]
+            else:
+                # Reshape [B, 1, H] to [1, B, H] -> extract each sample's true last-token
+                # hidden state ([B, H]) -> unsqueeze back to [1, B, H]
+                # (so that the output layer, which expects S x B x H, receives only the final token)
+                hidden_states = inference_context.last_token_logits(
+                    hidden_states.squeeze(1).unsqueeze(0)
+                ).unsqueeze(1)
         logits, _ = gpt_model.output_layer(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )
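A quick shape check of the squeeze/unsqueeze gymnastics described in the comment above; the last-token selection itself (`inference_context.last_token_logits`) is omitted, this only shows the layout changes:

```python
import torch

hidden = torch.randn(6, 1, 8)              # [B, 1, H] as described in the comment
as_sbh = hidden.squeeze(1).unsqueeze(0)    # [1, B, H]: the S x B x H layout the output layer expects
print(as_sbh.shape)                        # torch.Size([1, 6, 8])

# The selection step would keep one token per sample; a trailing unsqueeze(1)
# then restores a singleton dimension in the same spirit as the code above:
restored = as_sbh.squeeze(0).unsqueeze(1)  # back to [B, 1, H]
print(restored.shape)                      # torch.Size([6, 1, 8])
```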
@@ -233,31 +286,55 @@ class PostProcessNode(ScheduleNode):
 class TransformerLayerNode(ScheduleNode):
+    """Base class for transformer layer computation nodes.
+
+    This class provides common functionality for different types of
+    transformer layer nodes (attention, MLP, etc.)
+    """
+
-    def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
+    def __init__(self, stream, event, state, callables, name="default"):
+        """Initialize a transformer layer node.
+
+        Args:
+            stream (torch.cuda.Stream): CUDA stream for execution
+            event (torch.cuda.Event): Synchronization event
+            state (TransformerLayerState): State shared within a transformer layer
+            callables (Callable): The callables containing the forward and dw functions;
+                for MoE layers they run under the per_batch_state_context, otherwise a nullcontext
+            name (str): Node name, also used to determine memory strategy
+        """
+        # Get memory strategy based on node name
+        memory_strategy = MemoryStrategyRegistry.get_strategy_by_name(
+            name, callables.is_moe, callables.is_deepep
+        )
         super().__init__(
             weak_method(self.forward_impl),
             stream,
             event,
             weak_method(self.backward_impl),
-            free_inputs=free_inputs,
+            memory_strategy=memory_strategy,
+            name=name,
         )
-        # layer state
-        self.common_state = common_state
-        # model chunk state
-        self.chunk_state = chunk_state
-        self.layer = layer
+        self.common_state = state
+        self.callables = callables
         self.detached = tuple()
         self.before_detached = tuple()

     def detach(self, t):
+        """Detaches a tensor and stores it for backward computation."""
         detached = make_viewless(t).detach()
         detached.requires_grad = t.requires_grad
         self.before_detached = self.before_detached + (t,)
         self.detached = self.detached + (detached,)
         return detached
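The detach bookkeeping above cuts the autograd graph at node boundaries and later bridges gradients back through `before_detached`; a minimal sketch of that pattern with plain torch (not the node class itself):

```python
import torch

x = torch.randn(4, requires_grad=True)
t = x * 2                      # produced inside this node ("before_detached")
detached = t.detach()          # handed to the next node ("detached")
detached.requires_grad = True

out = (detached * 3).sum()     # next node's computation
out.backward()                 # grad lands on `detached`, not on `x` yet

# Node backward: bridge the gradient across the cut.
torch.autograd.backward(t, grad_tensors=detached.grad)
print(x.grad)                  # tensor([6., 6., 6., 6.]) == d(6x)/dx
```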
+    def forward_impl(self, *args):
+        """Implements the forward pass for the transformer layer node."""
+        return self.callables.forward(self, *args)
+
     def backward_impl(self, outputs, output_grad):
+        """Implements the backward pass for the transformer layer node."""
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
         self.default_backward_func(outputs + self.before_detached, grads)
@@ -266,201 +343,84 @@
         # return grads for record stream
         return grads

-
-class MoeAttnNode(TransformerLayerNode):
-    def forward_impl(self, hidden_states):
-        attention_mask = self.chunk_state.attention_mask
-        context = self.chunk_state.context
-        rotary_pos_emb = self.chunk_state.rotary_pos_emb
-        rotary_pos_cos = self.chunk_state.rotary_pos_cos
-        rotary_pos_sin = self.chunk_state.rotary_pos_sin
-        attention_bias = self.chunk_state.attention_bias
-        inference_context = self.chunk_state.inference_context
-        packed_seq_params = self.chunk_state.packed_seq_params
-        sequence_len_offset = self.chunk_state.sequence_len_offset
-        inference_params = self.chunk_state.inference_params
-
-        token_dispatcher = self.layer.mlp.token_dispatcher
-        with token_dispatcher.per_batch_state_context(self.common_state):
-            (
-                hidden_states,
-                pre_mlp_layernorm_output,
-                tokens_per_expert,
-                permutated_local_input_tokens,
-                permuted_probs,
-            ) = self.layer._submodule_attention_router_compound_forward(
-                hidden_states,
-                attention_mask=attention_mask,
-                rotary_pos_emb=rotary_pos_emb,
-                rotary_pos_cos=rotary_pos_cos,
-                rotary_pos_sin=rotary_pos_sin,
-                attention_bias=attention_bias,
-                inference_context=inference_context,
-                packed_seq_params=packed_seq_params,
-                sequence_len_offset=sequence_len_offset,
-                inference_params=inference_params,
-            )
-        self.common_state.tokens_per_expert = tokens_per_expert
-        # detached here
-        self.common_state.residual = self.detach(hidden_states)
-        self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
-        return permutated_local_input_tokens, permuted_probs
-
-    def dw(self):
-        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
-            self.layer._submodule_attention_router_compound_dw()
-
-
-class MoeDispatchNode(TransformerLayerNode):
-    def forward_impl(self, permutated_local_input_tokens, permuted_probs):
-        token_dispatcher = self.layer.mlp.token_dispatcher
-        with token_dispatcher.per_batch_state_context(self.common_state):
-            tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
-                self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
-            )
-        # release tensor not used by backward
-        # inputs.untyped_storage().resize_(0)
-        self.common_state.tokens_per_expert = tokens_per_expert
-        return global_input_tokens, global_probs
-
-
-class MoeMlPNode(TransformerLayerNode):
-    def forward_impl(self, global_input_tokens, global_probs):
-        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
-        token_dispatcher = self.layer.mlp.token_dispatcher
-        with token_dispatcher.per_batch_state_context(self.common_state):
-            expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
-                self.common_state.tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
-            )
-        assert mlp_bias is None
-        # pre_mlp_layernorm_output used
-        self.common_state.pre_mlp_layernorm_output = None
-        return expert_output, shared_expert_output
-
-    def dw(self):
-        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
-            self.layer._submodule_mlp_dw()
-
-
-class MoeCombineNode(TransformerLayerNode):
-    def forward_impl(self, expert_output, shared_expert_output):
-        # TODO(lhb): if dw use grad of residual and probs, necessary synchronization should be add
-        residual = self.common_state.residual
-        token_dispatcher = self.layer.mlp.token_dispatcher
-        with token_dispatcher.per_batch_state_context(self.common_state):
-            permutated_local_input_tokens = token_dispatcher.combine_all_to_all(
-                expert_output
-            )
-            output = self.layer._submodule_post_combine_forward(
-                permutated_local_input_tokens, shared_expert_output, None, residual
-            )
-        cur_stream = torch.cuda.current_stream()
-        self.common_state.residual.record_stream(cur_stream)
-        self.common_state.residual = None
-        return output
-
-
-class DenseAttnNode(TransformerLayerNode):
-    def forward_impl(self, hidden_states):
-        attention_mask = self.chunk_state.attention_mask
-        rotary_pos_emb = self.chunk_state.rotary_pos_emb
-        rotary_pos_cos = self.chunk_state.rotary_pos_cos
-        rotary_pos_sin = self.chunk_state.rotary_pos_sin
-        attention_bias = self.chunk_state.attention_bias
-        inference_context = self.chunk_state.inference_context
-        packed_seq_params = self.chunk_state.packed_seq_params
-        sequence_len_offset = self.chunk_state.sequence_len_offset
-        inference_params = self.chunk_state.inference_params
-        hidden_states = self.layer._submodule_attention_forward(
-            hidden_states,
-            attention_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_context,
-            packed_seq_params,
-            sequence_len_offset,
-            inference_params=inference_params,
-        )
-        return hidden_states
-
-    def dw(self):
-        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
-            self.layer._submodule_attention_dw()
-
-
-class FakeScheduleNode:
-    def forward(self, inputs):
-        return inputs
-
-    def backward(self, outgrads):
-        return outgrads
-
-
-class DenseMlpNode(TransformerLayerNode):
-    def forward_impl(self, hidden_states):
-        return self.layer._submodule_dense_forward(hidden_states)
-
-    def dw(self):
-        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
-            self.layer._submodule_mlp_dw()
-
-
-def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
-    common_state = TransformerLayerState()
-    attn = DenseAttnNode(chunk_state, common_state, layer, comp_stream, event)
-    attn.name = "attn"
-    dispatch = FakeScheduleNode()
-    mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
-    mlp.name = "mlp"
-    combine = FakeScheduleNode()
-    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
-
-
-def build_layer_schedule_plan(layer, event, chunk_state, comp_stream, com_stream):
-    if not isinstance(layer.mlp, MoELayer):
-        return build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream)
-    common_state = TransformerLayerState()
-    attn = MoeAttnNode(chunk_state, common_state, layer, comp_stream, event)
-    attn.name = "attn"
-    dispatch = MoeDispatchNode(chunk_state, common_state, layer, com_stream, event, True)
-    dispatch.name = "dispatch"
-    mlp = MoeMlPNode(chunk_state, common_state, layer, comp_stream, event, True)
-    mlp.name = "mlp"
-    combine = MoeCombineNode(chunk_state, common_state, layer, com_stream, event, True)
-    combine.name = "combine"
-    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
-
-
-class TransformerLayerState(MoEAlltoAllPerBatchState):
-    pass
-
-
-class ModelChunkSate:
-    pass
-
-
-class TransformerLayerSchedulePlan:
-    def __init__(self, attn, dispatch, mlp, combine):
-        self.attn = attn
-        self.dispatch = dispatch
-        self.mlp = mlp
-        self.combine = combine
-
-
-class ModelChunkSchedulePlan(AbstractSchedulePlan):
+    def dw(self):
+        """Computes the weight gradients for the transformer layer node."""
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.callables.dw()
+
+
+class TransformerLayerState:
+    """State shared within a transformer layer.
+
+    This class holds state that is shared between different nodes
+    within a transformer layer.
+    """
+
+    pass
+
+
+class ModelChunkSate:
+    """State shared across a model chunk.
+
+    This class holds state that is shared between different components
+    of a model chunk, such as input tensors, parameters, and configuration.
+    """
+
+    pass
+
+
+class TransformerLayerSchedulePlan:
+    """Schedule plan for a transformer layer.
+
+    This class organizes the computation nodes for a transformer layer,
+    including attention, MLP, dispatch, and combine nodes.
+    """
+
+    def __init__(self, layer, event, chunk_state, comp_stream, com_stream):
+        """Initializes a transformer layer schedule plan.
+
+        Args:
+            layer (TransformerLayer): The transformer layer to schedule.
+            event (torch.cuda.Event): CUDA event for synchronization.
+            chunk_state (ModelChunkState): State shared across the model chunk.
+            comp_stream (torch.cuda.Stream): CUDA stream for computation.
+            com_stream (torch.cuda.Stream): CUDA stream for communication.
+        """
+        self.common_state = TransformerLayerState()
+        # get callables for transformer layer
+        attn_callable, dispatch_callable, mlp_callable, combine_callable = (
+            layer.get_submodule_callables(chunk_state).as_array()
+        )
+
+        # Create nodes for different operations in the layer
+        # Each node type has a predefined name that determines its memory strategy
+        self.attn = TransformerLayerNode(
+            comp_stream, event, self.common_state, attn_callable, name="attn"
+        )
+        self.mlp = TransformerLayerNode(
+            comp_stream, event, self.common_state, mlp_callable, name="mlp"
+        )
+        if attn_callable.is_moe:
+            self.dispatch = TransformerLayerNode(
+                com_stream, event, self.common_state, dispatch_callable, name="dispatch"
+            )
+            self.combine = TransformerLayerNode(
+                com_stream, event, self.common_state, combine_callable, name="combine"
+            )
+        else:
+            self.dispatch = FakeScheduleNode()
+            self.combine = FakeScheduleNode()
+
+
+class ModelChunkSchedulePlan(AbstractSchedulePlan):
+    """Schedule plan for a model chunk.
+
+    This class organizes the computation nodes for a model chunk,
+    including preprocessing, transformer layers, and postprocessing.
+    """
+
     def __init__(self):
+        """Initializes a model chunk schedule plan."""
         super().__init__()
         self._pre_process = None
         self._post_process = None
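The rewritten plan always builds the same four node slots per layer (attn, dispatch, mlp, combine), with FakeScheduleNode filling the communication slots for dense layers. A toy sketch of that wiring and of the plain forward order attn -> dispatch -> mlp -> combine (names and lambdas below are stand-ins, not the real node implementations):

```python
class ToyNode:
    def __init__(self, fn):
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class FakeScheduleNode:
    def forward(self, inputs):
        return inputs

    def backward(self, outgrads):
        return outgrads


class ToyLayerPlan:
    def __init__(self, is_moe):
        self.attn = ToyNode(lambda x: x + 1)
        self.mlp = ToyNode(lambda x: x * 2)
        self.dispatch = ToyNode(lambda x: x) if is_moe else FakeScheduleNode()
        self.combine = ToyNode(lambda x: x) if is_moe else FakeScheduleNode()


def layer_forward(plan, x):
    x = plan.attn.forward(x)
    x = plan.dispatch.forward(x)  # all-to-all for MoE, pass-through for dense
    x = plan.mlp.forward(x)
    return plan.combine.forward(x)


print(layer_forward(ToyLayerPlan(is_moe=False), 3))  # (3 + 1) * 2 = 8
```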
@@ -481,7 +441,22 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
         post_forward=None,
         post_backward=None,
     ):
+        """Schedules forward and backward passes for model chunks.
+
+        Args:
+            f_schedule_plan (ModelChunkSchedulePlan): Forward schedule plan.
+            b_schedule_plan (ModelChunkSchedulePlan): Backward schedule plan.
+            grad (Tensor): Gradient for backward computation.
+            f_context (VppContextManager or None): The VppContextManager for the forward pass.
+            b_context (VppContextManager or None): The VppContextManager for the backward pass.
+            pre_forward (Callable): Callback for preprocessing in forward pass.
+            pre_backward (Callable): Callback for preprocessing in backward pass.
+            post_forward (Callable): Callback for postprocessing in forward pass.
+            post_backward (Callable): Callback for postprocessing in backward pass.
+
+        Returns:
+            The output of the forward pass.
+        """
         return schedule_chunk_1f1b(
             f_schedule_plan,
             b_schedule_plan,
@@ -496,48 +471,57 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
     @property
     def event(self):
+        """Gets the CUDA event for synchronization."""
         return self._event

     def record_current_stream(self):
+        """Records the current CUDA stream in the event."""
         stream = torch.cuda.current_stream()
         self.event.record(stream)

     def wait_current_stream(self):
+        """Waits for the event to complete on the current CUDA stream."""
         stream = torch.cuda.current_stream()
         self.event.wait(stream)

     @property
     def pre_process(self):
+        """Gets the preprocessing node."""
         return self._pre_process

     @pre_process.setter
     def pre_process(self, value):
+        """Sets the preprocessing node."""
         self._pre_process = value

     @property
     def post_process(self):
+        """Gets the postprocessing node."""
         return self._post_process

     @post_process.setter
     def post_process(self, value):
+        """Sets the postprocessing node."""
         self._post_process = value

     def get_layer(self, i):
+        """Gets the transformer layer at the specified index."""
         assert i < self.num_layers()
         return self._transformer_layers[i]

     def num_layers(self):
+        """Gets the number of transformer layers."""
         return len(self._transformer_layers)

     def add_layer(self, layer):
+        """Adds a transformer layer to the schedule plan."""
         self._transformer_layers.append(layer)

     @property
     def state(self):
+        """Gets the model chunk state."""
         return self._model_chunk_state


-# F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
-F_DISPATCH_B_MLP_SYNC_EVENT = None
-
-
 def schedule_layer_1f1b(
     f_layer,
@@ -550,6 +534,25 @@ def schedule_layer_1f1b(
     f_context=None,
     b_context=None,
 ):
+    """Schedule one-forward-one-backward operations for a single layer.
+
+    This function interleaves forward and backward operations to maximize
+    parallelism and efficiency.
+
+    Args:
+        f_layer (TransformerLayerSchedulePlan): Forward layer (for current microbatch)
+        b_layer (TransformerLayerSchedulePlan): Backward layer (for previous microbatch)
+        f_input (Tensor): Input for forward computation
+        b_grad (Tensor): Gradient for backward computation
+        pre_forward (Callable): Callback to get forward input if not provided
+        pre_backward (Callable): Callback to get backward gradient if not provided
+        pre_backward_dw (Callable): Callback for weight gradient computation
+        f_context (VppContextManager or None): The VppContextManager for the forward pass.
+        b_context (VppContextManager or None): The VppContextManager for the backward pass.
+
+    Returns:
+        Functions or values for next iteration's computation
+    """
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()
@@ -577,17 +580,13 @@
         with f_context:
             f_input = f_layer.attn.forward(f_input)

-    f_dispatch_b_mlp_sync_event = None
-    if f_layer is not None and b_layer is not None:
-        f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
-
     if f_layer is not None:
         with f_context:
-            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
+            f_input = f_layer.dispatch.forward(f_input)

     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
+            b_grad = b_layer.mlp.backward(b_grad)
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()
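For orientation, a print-only sketch of the interleaving this hunk participates in. Only the steps visible in the hunk (forward attn/dispatch overlapped with backward mlp/dispatch and mlp.dw) are grounded in the diff; the combine/attn steps are assumptions about the surrounding, unshown code, and streams/events are omitted entirely:

```python
class StubNode:
    def __init__(self, label):
        self.label = label

    def forward(self, x):
        print(f"F {self.label}")
        return x

    def backward(self, g):
        print(f"B {self.label}")
        return g

    def dw(self):
        print(f"W {self.label}")


class StubLayer:
    def __init__(self, tag):
        self.attn = StubNode(f"{tag}.attn")
        self.dispatch = StubNode(f"{tag}.dispatch")
        self.mlp = StubNode(f"{tag}.mlp")
        self.combine = StubNode(f"{tag}.combine")


def schedule_layer_1f1b_sketch(f_layer, b_layer, f_input=None, b_grad=None):
    if b_layer is not None:
        b_grad = b_layer.combine.backward(b_grad)     # assumed step
    if f_layer is not None:
        f_input = f_layer.attn.forward(f_input)       # from the hunk above
        f_input = f_layer.dispatch.forward(f_input)   # from the hunk above
    if b_layer is not None:
        b_grad = b_layer.mlp.backward(b_grad)         # from the hunk above
        b_grad = b_layer.dispatch.backward(b_grad)    # from the hunk above
        b_layer.mlp.dw()                              # from the hunk above
    if f_layer is not None:
        f_input = f_layer.mlp.forward(f_input)        # assumed step
        f_input = f_layer.combine.forward(f_input)    # assumed step
    if b_layer is not None:
        b_grad = b_layer.attn.backward(b_grad)        # assumed step
        b_layer.attn.dw()                             # assumed step
    return f_input, b_grad


schedule_layer_1f1b_sketch(StubLayer("f"), StubLayer("b"), f_input=0, b_grad=0)
```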
@@ -629,13 +628,32 @@
     post_forward=None,
     post_backward=None,
 ):
+    """Schedules one-forward-one-backward operations for a model chunk.
+
+    This function interleaves forward and backward operations across multiple layers
+    to maximize parallelism and efficiency.
+
+    Args:
+        f_schedule_plan: Forward schedule plan.
+        b_schedule_plan: Backward schedule plan.
+        grad: Gradient for backward computation.
+        f_context: Context for forward computation.
+        b_context: Context for backward computation.
+        pre_forward: Callback for preprocessing in forward pass.
+        pre_backward: Callback for preprocessing in backward pass.
+        post_forward: Callback for postprocessing in forward pass.
+        post_backward: Callback for postprocessing in backward pass.
+
+    Returns:
+        The output of the forward pass.
+    """
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()

     if f_schedule_plan:
         # pp output send/receive sync
         if pre_forward is not None:
-            with f_context:
+            with f_context:  # virtual pipeline parallel context
                 pre_forward()
             f_schedule_plan.record_current_stream()
@@ -655,14 +673,14 @@
     if b_schedule_plan is not None:
         assert grad is not None
         if b_schedule_plan.post_process is not None:
-            with b_context:
+            with b_context:  # virtual pipeline parallel context
                 tmp = b_schedule_plan.post_process.backward(grad)

         if pre_backward is not None:
             # pp grad send receive sync here, safe for now, maybe not safe in the future
             with torch.cuda.stream(get_com_stream()):
                 b_schedule_plan.wait_current_stream()
-                with b_context:
+                with b_context:  # virtual pipeline parallel context
                     pre_backward()
                 b_schedule_plan.record_current_stream()
@@ -693,32 +711,40 @@
         # tail forward
         f_input = layer_pre_forward()
         del layer_pre_forward

     # tail backward
     grad = layer_pre_backward()
     del layer_pre_backward

     with b_context:
         for i in range(overlaped_layers, b_num_layers):
             b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
             torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
-            _, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
+            tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
             torch.cuda.nvtx.range_pop()

+    # if b_schedule_plan is not None:
+    #     b_schedule_plan.pre_process.backward(grad)
+
+    # # tail forward
+    # f_input = layer_pre_forward()
+    # del layer_pre_forward
+
     with f_context:
         for i in range(overlaped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
             torch.cuda.nvtx.range_push(f"layer_{i}f")
-            f_input, _, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
+            f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()

+    # if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
+    #     f_input = f_schedule_plan.post_process.forward(f_input)
+
     # output pp send receive, overlapped with attn backward
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
             f_schedule_plan.wait_current_stream()
             post_forward(f_input)

-    # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
+    # pp grad send / receive, overlapped with attn dw of cur micro-batch
+    # and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
         with b_context:
             b_schedule_plan.wait_current_stream()
@@ -750,14 +776,31 @@ def build_model_chunk_schedule_plan(
     attention_mask: Tensor,
     decoder_input: Tensor = None,
     labels: Tensor = None,
-    inference_context: BaseInferenceContext = None,
-    packed_seq_params: PackedSeqParams = None,
-    extra_block_kwargs: dict = None,
+    inference_params=None,
+    packed_seq_params=None,
+    extra_block_kwargs=None,
     runtime_gather_output: Optional[bool] = None,
-    inference_params: Optional[BaseInferenceContext] = None,
-    loss_mask: Optional[Tensor] = None
 ):
+    """Builds a schedule plan for a model chunk.
+
+    This function creates a schedule plan for a model chunk, including
+    preprocessing, transformer layers, and postprocessing.
+
+    Args:
+        model: The model to build a schedule plan for.
+        input_ids: Input token IDs.
+        position_ids: Position IDs.
+        attention_mask: Attention mask.
+        decoder_input: Decoder input tensor.
+        labels: Labels for loss computation.
+        inference_params: Parameters for inference.
+        packed_seq_params: Parameters for packed sequences.
+        extra_block_kwargs: Additional keyword arguments for blocks.
+        runtime_gather_output: Whether to gather output at runtime.
+
+    Returns:
+        The model chunk schedule plan.
+    """
     comp_stream = get_comp_stream()
     com_stream = get_com_stream()
     model_chunk_schedule_plan = ModelChunkSchedulePlan()
@@ -769,28 +812,23 @@
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
-    state.inference_context = inference_context
+    state.inference_params = inference_params
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
-    state.inference_params = inference_params
-    state.loss_mask = loss_mask
     state.context = None
     state.context_mask = None
     state.attention_bias = None

     # build preprocess
     model_chunk_schedule_plan.pre_process = PreProcessNode(model, state, event, comp_stream)
-    model_chunk_schedule_plan.pre_process.name = "pre_process"

     # build for layers
     for layer_idx in range(model.decoder.num_layers_per_pipeline_rank):
         layer = model.decoder._get_layer(layer_idx)
-        layer_plan = build_layer_schedule_plan(layer, event, state, comp_stream, com_stream)
+        layer_plan = TransformerLayerSchedulePlan(layer, event, state, comp_stream, com_stream)
         model_chunk_schedule_plan.add_layer(layer_plan)

     # build post process
     if model.post_process:
         model_chunk_schedule_plan.post_process = PostProcessNode(model, state, event, comp_stream)
-        model_chunk_schedule_plan.post_process.name = "post_process"
     return model_chunk_schedule_plan
@@ -10,7 +10,7 @@ from torch.autograd.variable import Variable
 from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel
+from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
@@ -20,13 +20,14 @@ from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 def make_viewless(e):
-    """make_viewless util func"""
+    """Make_viewless util func"""
     e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
     return e


 @contextmanager
 def stream_acquire_context(stream, event):
+    """Stream acquire context"""
     event.wait(stream)
     try:
         yield
@@ -34,8 +35,29 @@ def stream_acquire_context(stream, event):
         event.record(stream)
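A usage sketch of the event-guarded stream context: each node waits on the shared event before running on its stream and re-records it afterwards, so work on different streams is serialized in program order. The re-record is assumed to sit in a finally block (only the wait/try/record lines are visible above), and the demo only runs when a CUDA device is present:

```python
import contextlib
import torch


@contextlib.contextmanager
def stream_acquire_context(stream, event):
    # Mirror of the helper above: wait before using the stream, re-record afterwards.
    event.wait(stream)
    try:
        yield
    finally:
        event.record(stream)


if torch.cuda.is_available():
    comp_stream = torch.cuda.Stream()
    com_stream = torch.cuda.Stream()
    event = torch.cuda.Event()
    event.record(torch.cuda.current_stream())

    with stream_acquire_context(comp_stream, event):
        with torch.cuda.stream(comp_stream):
            a = torch.ones(1024, device="cuda") * 2   # "attn"-like work on the compute stream

    with stream_acquire_context(com_stream, event):
        with torch.cuda.stream(com_stream):
            b = a + 1                                  # "dispatch"-like work, ordered after it

    torch.cuda.synchronize()
    print(b.sum().item())
```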
+class FakeScheduleNode:
+    """A placeholder node in the computation graph that simply passes through inputs and outputs.
+
+    This class is used as a no-op node in the scheduling system when a real computation node
+    is not needed but the interface must be maintained. It simply returns its inputs unchanged
+    in both forward and backward passes.
+    """
+
+    def forward(self, inputs):
+        """Passes through inputs unchanged in the forward pass."""
+        return inputs
+
+    def backward(self, outgrads):
+        """Passes through gradients unchanged in the backward pass."""
+        return outgrads
+
+
 class ScheduleNode:
-    """base node for fine-grained schedule"""
+    """Base node for fine-grained scheduling.
+
+    This class represents a computational node in the pipeline schedule.
+    It handles the execution of forward and backward operations on a stream.
+    """

     def __init__(
         self,
@@ -43,24 +65,30 @@ class ScheduleNode:
         stream,
         event,
         backward_func=None,
-        free_inputs=False,
+        memory_strategy=None,
         name="schedule_node",
     ):
+        """Initialize a schedule node.
+
+        Args:
+            forward_func (callable): Function to execute during forward pass
+            stream (torch.cuda.Stream): CUDA stream for computation
+            event (torch.cuda.Event): Event for synchronization
+            backward_func (callable, optional): Function for backward pass
+            memory_strategy (MemoryManagementStrategy, optional): Strategy for memory management
+            name (str): Name of the node for debugging
+        """
         self.name = name
         self.forward_func = forward_func
-        self.backward_func = backward_func
+        self.backward_func = backward_func if backward_func else self.default_backward_func
         self.stream = stream
         self.event = event
-        self.free_inputs = free_inputs
+        self.memory_strategy = memory_strategy or NoOpMemoryStrategy()
         self.inputs = None
         self.outputs = None

     def default_backward_func(self, outputs, output_grad):
-        # Handle scalar output
-        if output_grad is None:
-            assert outputs.numel() == 1, "implicit grad requires scalar output."
-            output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
+        """Default backward function"""
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
@@ -72,20 +100,16 @@ class ScheduleNode:
         )
         return output_grad
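The node backward drives autograd through the internal execution engine; for orientation, the public-API call with the same shape (explicit output tensors plus their gradients) looks like this:

```python
import torch

x = torch.randn(3, requires_grad=True)
outputs = (x * 4,)
output_grad = (torch.ones(3),)

# Public-API equivalent of backpropagating a node's outputs with explicit grads.
torch.autograd.backward(
    tensors=outputs,
    grad_tensors=output_grad,
    retain_graph=False,
)
print(x.grad)  # tensor([4., 4., 4.])
```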
-    def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
-        """schedule node forward"""
+    def forward(self, inputs=()):
+        """Schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
-        return self._forward(*inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
+        return self._forward(*inputs)

-    def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
+    def _forward(self, *inputs):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} forward")
             with torch.cuda.stream(self.stream):
-                if stream_wait_event is not None:
-                    stream_wait_event.wait(self.stream)
                 self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
                 for i, input in enumerate(self.inputs):
                     if input is not None:
@@ -100,50 +124,35 @@ class ScheduleNode:
                 data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
                 self.output = data

-                if stream_record_event is not None:
-                    stream_record_event.record(self.stream)
             torch.cuda.nvtx.range_pop()

-            if self.free_inputs:
-                for input in inputs:
-                    input.record_stream(self.stream)
-                    input.untyped_storage().resize_(0)
+            # Handle inputs using the memory strategy
+            self.memory_strategy.handle_inputs(inputs, self.stream)
         return self.output

     def get_output(self):
-        """get the forward output"""
+        """Get the forward output"""
         return self.output

-    def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
-        """schedule node backward"""
+    def backward(self, output_grad):
+        """Schedule node backward"""
         if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
-        return self._backward(*output_grad, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)
+        return self._backward(*output_grad)

-    def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
+    def _backward(self, *output_grad):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} backward")
             with torch.cuda.stream(self.stream):
-                if stream_wait_event is not None:
-                    stream_wait_event.wait(self.stream)
                 outputs = self.output
                 if not isinstance(outputs, tuple):
                     outputs = (outputs,)
-                assert len(outputs) == len(
-                    output_grad
-                ), f"{len(outputs)} of {type(outputs[0])} vs {len(output_grad)} of {type(output_grad[0])}"
-                if self.backward_func is not None:
-                    output_grad = self.backward_func(outputs, output_grad)
-                else:
-                    output_grad = self.default_backward_func(outputs, output_grad)
-                if stream_record_event is not None:
-                    stream_record_event.record(self.stream)
+                assert len(outputs) == len(output_grad), (
+                    f"{len(outputs)} of {type(outputs[0])} is not equal to "
+                    f"{len(output_grad)} of {type(output_grad[0])}"
+                )
+                output_grad = self.backward_func(outputs, output_grad)
             torch.cuda.nvtx.range_pop()

         # output_grad maybe from another stream
@@ -153,7 +162,7 @@ class ScheduleNode:
         return self.get_grad()

     def get_grad(self):
-        """get the grad of inputs"""
+        """Get the grad of inputs"""
         grad = tuple([e.grad if e is not None else None for e in self.inputs])
         # clear state
         self.inputs = None
@@ -165,7 +174,7 @@ class ScheduleNode:
 class AbstractSchedulePlan(ABC):
-    """to use combined 1f1b, model must implement build_schedule_plan while take the same
+    """To use combined 1f1b, model must implement build_schedule_plan while take the same
     signature as model forward but return an instance of AbstractSchedulePlan"""

     @classmethod
@@ -197,7 +206,29 @@ def schedule_chunk_1f1b(
     post_forward=None,
     post_backward=None,
 ):
-    """model level 1f1b fine-grained schedule"""
+    """Model level 1f1b fine-grained schedule
+
+    This function schedules the forward and backward passes for a chunk of the model.
+    It takes in the forward schedule plan, backward schedule plan, gradient, and optional
+    context managers for the forward and backward passes.
+
+    Args:
+        f_schedule_plan (subclass of AbstractSchedulePlan): The forward schedule plan
+        b_schedule_plan (subclass of AbstractSchedulePlan): The backward schedule plan
+        grad (Tensor or None): The gradient of the loss function
+        f_context (VppContextManager or None): The VppContextManager for the forward pass
+        b_context (VppContextManager or None): The VppContextManager for the backward pass
+        pre_forward (callable or None): The function to call before the forward pass
+        pre_backward (callable or None): The function to call before the backward pass
+        post_forward (callable or None): The function to call after the forward pass
+        post_backward (callable or None): The function to call after the backward pass
+
+    Returns:
+        The output of the forward pass
+    """
+    # Calls fine_grained_schedule.py::ModelChunkSchedulePlan.forward_backward(),
+    # which calls fine_grained_schedule.py::schedule_chunk_1f1b()
     return type(f_schedule_plan or b_schedule_plan).forward_backward(
         f_schedule_plan,
         b_schedule_plan,
@@ -216,7 +247,7 @@ _COM_STREAM = None
 def set_streams(comp_stream=None, com_stream=None):
-    """set the streams for communication and computation"""
+    """Set the streams for communication and computation"""
     global _COMP_STREAM
     global _COM_STREAM
     if _COMP_STREAM is not None:
@@ -234,19 +265,19 @@ def set_streams(comp_stream=None, com_stream=None):
 def get_comp_stream():
-    """get the stream for computation"""
+    """Get the stream for computation"""
     global _COMP_STREAM
     return _COMP_STREAM


 def get_com_stream():
-    """get the stream for communication"""
+    """Get the stream for communication"""
     global _COM_STREAM
     return _COM_STREAM


 class VppContextManager:
-    """a reusable context manager for switch vpp stage"""
+    """A reusable context manager for switch vpp stage"""

     def __init__(self, vpp_rank):
         self.vpp_rank = vpp_rank
@@ -353,9 +384,17 @@ def forward_backward_step(
         Tensor or list[Tensor]: The output object(s) from the forward step.
         Tensor: The number of tokens.
     """
+    assert (
+        checkpoint_activations_microbatch is None
+    ), "checkpoint_activations_microbatch is not supported for combined_1f1b"
+
+    if config.combined_1f1b_recipe != "ep_a2a":
+        raise NotImplementedError(
+            f"combined_1f1b_recipe {config.combined_1f1b_recipe} not supported yet"
+        )
+
     from megatron.core.pipeline_parallel.schedules import set_current_microbatch

-    if config.timers is not None:
+    if f_model is not None and config.timers is not None:
         config.timers('forward-compute', log_level=2).start()

     if config.enable_autocast:
@@ -364,6 +403,8 @@
         context_manager = contextlib.nullcontext()

     # forward preprocess
+    unwrap_output_tensor = False
+    f_schedule_plan = None
     if f_model is not None:
         with f_context:
             if is_first_microbatch and hasattr(f_model, 'set_is_first_microbatch'):
@@ -371,7 +412,6 @@
             if current_microbatch is not None:
                 set_current_microbatch(f_model, current_microbatch)

-            unwrap_output_tensor = False
             if not isinstance(input_tensor, list):
                 input_tensor = [input_tensor]
                 unwrap_output_tensor = True
@@ -381,20 +421,20 @@
             with context_manager:
                 if checkpoint_activations_microbatch is None:
-                    output_tensor, loss_func = forward_step_func(data_iterator, f_model)
+                    f_schedule_plan, loss_func = forward_step_func(data_iterator, f_model)
                 else:
-                    output_tensor, loss_func = forward_step_func(
+                    f_schedule_plan, loss_func = forward_step_func(
                         data_iterator, f_model, checkpoint_activations_microbatch
                     )

             assert isinstance(
-                output_tensor, AbstractSchedulePlan
+                f_schedule_plan, AbstractSchedulePlan
             ), "first output of forward_step_func must be one instance of AbstractSchedulePlan"

     # backward preprocess
+    unwrap_input_tensor_grad = False
     b_schedule_plan = None
     if b_model is not None:
         # Retain the grad on the input_tensor.
-        unwrap_input_tensor_grad = False
         if not isinstance(b_input_tensor, list):
             b_input_tensor = [b_input_tensor]
             unwrap_input_tensor_grad = True
@@ -418,9 +458,8 @@
         torch.autograd.backward(b_output_tensor[0], grad_tensors=b_output_tensor_grad[0])
         b_output_tensor_grad[0] = loss_node.get_grad()

-    f_schedule_plan = output_tensor if f_model else None
     grad = b_output_tensor_grad[0] if b_model else None
-    with context_manager:
+    with context_manager:  # autocast context
         # schedule forward and backward
         output_tensor = schedule_chunk_1f1b(
             f_schedule_plan,
@@ -436,7 +475,7 @@
     # forward post process
     num_tokens = None
-    if f_model:
+    if f_model is not None:
         with f_context:
             model_vp_stage = getattr(f_model, "vp_stage", None)
             if vp_stage is not None and model_vp_stage is not None:
@@ -535,7 +574,18 @@
     return output_tensor, num_tokens, input_tensor_grad
 def get_default_cls_for_unwrap():
+    """Returns the default classes to unwrap from a model.
+
+    This function provides a tuple of classes that should be unwrapped from a model
+    to access the underlying GPTModel instance. It includes DistributedDataParallel
+    and Float16Module by default, and also attempts to include LegacyFloat16Module
+    if available for backward compatibility.
+
+    Returns:
+        tuple: A tuple of classes to unwrap from a model.
+    """
     cls = (DistributedDataParallel, Float16Module)
     try:
         # legacy should not be used in core, but for backward compatibility, we support it here
@@ -547,7 +597,9 @@ def get_default_cls_for_unwrap():
 def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
-    """unwrap_model DistributedDataParallel and Float16Module wrapped model"""
+    """Unwrap_model DistributedDataParallel and Float16Module wrapped model
+    to return GPTModel instance
+    """
     return_list = True
     if not isinstance(model, list):
         model = [model]
@@ -556,19 +608,80 @@
     for model_module in model:
         while isinstance(model_module, module_instances):
             model_module = model_module.module
+        assert isinstance(
+            model_module, GPTModel
+        ), "The final unwrapped model must be a GPTModel instance"
         unwrapped_model.append(model_module)
     if not return_list:
         return unwrapped_model[0]
     return unwrapped_model
-def wrap_forward_func(config, forward_step_func):
-    """wrap the input to forward_step_func, to make forward_step_func return schedule plan"""
+def wrap_forward_func(forward_step_func):
+    """Wrap the input to forward_step_func.
+
+    The wrapped function will return forward_schedule_plan and the loss_function.
+    """

     def wrapped_func(data_iterator, model):
+        # Model is unwrapped to get GPTModel instance.
+        # GPTModel.build_schedule_plan(model_forward_inputs) is called in the forward_step.
+        # The return value becomes (forward_schedule_plan, loss_function),
+        # which used to be (forward_output_tensor, loss_function).
         return forward_step_func(data_iterator, unwrap_model(model).build_schedule_plan)

-    if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a":
-        return wrapped_func
-    else:
-        return forward_step_func
+    return wrapped_func
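A toy illustration of the wrapping trick (`ToyGPT` and the strings below are stand-ins, not Megatron APIs): the user-supplied forward_step_func keeps its `(data_iterator, model)` signature, but the "model" it receives is actually the unwrapped model's build_schedule_plan method, so calling it yields a schedule plan instead of an output tensor.

```python
class ToyGPT:
    def __call__(self, batch):
        return f"logits({batch})"

    def build_schedule_plan(self, batch):
        return f"schedule_plan({batch})"


def forward_step_func(data_iterator, model):
    batch = next(data_iterator)
    return model(batch), "loss_func"


def wrap_forward_func(forward_step_func):
    def wrapped_func(data_iterator, model):
        return forward_step_func(data_iterator, model.build_schedule_plan)
    return wrapped_func


it = iter(["batch0"])
print(wrap_forward_func(forward_step_func)(it, ToyGPT()))  # ('schedule_plan(batch0)', 'loss_func')
```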
+class MemoryManagementStrategy:
+    """Base class for memory management strategies.
+
+    Different memory management strategies can be implemented by subclassing this class.
+    These strategies control how tensors are handled in memory during the computation.
+    """
+
+    def handle_inputs(self, inputs, stream):
+        """Process input tensors after computation.
+
+        Args:
+            inputs (tuple): Input tensors that have been used
+            stream (torch.cuda.Stream): Current CUDA stream
+        """
+        pass
+
+    def handle_outputs(self, outputs, stream):
+        """Process output tensors after computation.
+
+        Args:
+            outputs (tuple): Output tensors produced by the computation
+            stream (torch.cuda.Stream): Current CUDA stream
+        """
+        pass
+
+
+class NoOpMemoryStrategy(MemoryManagementStrategy):
+    """Strategy that performs no memory management operations.
+
+    This is the default strategy - it doesn't free any memory.
+    """
+
+    pass
+
+
+class FreeInputsMemoryStrategy(MemoryManagementStrategy):
+    """Strategy that immediately frees input tensors after they are used.
+
+    This strategy is useful for nodes where inputs are no longer needed
+    after computation, helping to reduce memory usage.
+    """
+
+    def handle_inputs(self, inputs, stream):
+        """Free input tensors by resizing their storage to zero.
+
+        Args:
+            inputs (tuple): Input tensors to be freed
+            stream (torch.cuda.Stream): Current CUDA stream
+        """
+        for input in inputs:
+            if input is not None:
+                input.record_stream(stream)
+                input.untyped_storage().resize_(0)
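What FreeInputsMemoryStrategy actually does to a tensor can be seen with plain CPU tensors (record_stream is CUDA-only and is skipped here): the backing storage is released in place while the Python object and its shape survive.

```python
import torch

t = torch.randn(1024, 1024)
print(t.untyped_storage().size())  # 4194304 bytes of backing storage

t.untyped_storage().resize_(0)     # free the backing memory; metadata remains
print(t.untyped_storage().size())  # 0
print(t.shape)                     # torch.Size([1024, 1024]) -- but the data is gone
```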
@@ -94,7 +94,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         self.collect_per_batch_state(state)
         self.apply_per_batch_state(origin_state)

-    def meta_prepare(
+    def dispatch_preprocess(
         self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
     ):
         self.hidden_shape = hidden_states.shape
@@ -112,9 +112,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             self.routing_map = pad_routing_map(self.routing_map, pad_multiple)

         tokens_per_expert = self.preprocess(self.routing_map)
-        return tokens_per_expert
-
-    def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
         if self.shared_experts is not None:
             self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
@@ -235,8 +232,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         """
         # Preprocess: Get the metadata for communication, permutation and computation operations.
         # Permutation 1: input to AlltoAll input
-        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
-        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, probs, routing_map)
         # Perform expert parallel AlltoAll communication
         tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
@@ -193,7 +193,7 @@ def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[i
class TransformerLayer(MegatronCoreTransformerLayer):

-    def forward(
+    def _submodule_attn_router_forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
@@ -209,12 +209,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        *,
        inference_params: Optional[Any] = None,
    ):
-        if (
-            not isinstance(self.mlp, MoELayer)
-            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
-        ):
-            return super().forward(
+        """
+        Performs a combined forward pass that includes self-attention and MLP routing logic.
+        """
+        pre_mlp_layernorm_output, residual, context = self._forward_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            context=context,
@@ -229,180 +227,80 @@ class TransformerLayer(MegatronCoreTransformerLayer):
            inference_params=inference_params,
        )
-        (
-            hidden_states,
-            pre_mlp_layernorm_output,
-            tokens_per_expert,
-            permutated_local_input_tokens,
-            permuted_probs,
-        ) = self._submodule_attention_router_compound_forward(
-            hidden_states,
-            attention_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_context,
-            packed_seq_params,
-            sequence_len_offset,
-            inference_params=inference_params,
-        )
-
-        (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
-            tokens_per_expert,
-            permutated_local_input_tokens,
-            permuted_probs,
-        )
-
-        (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
-            tokens_per_expert,
-            global_input_tokens,
-            global_probs,
-            pre_mlp_layernorm_output
-        )
-
-        expert_output = self._submodule_combine_forward(expert_output)[0]
-
-        output = self._submodule_post_combine_forward(
-            expert_output,
-            shared_expert_output,
-            mlp_bias,
-            hidden_states
-        )
-
-        return output, None
-
-    def _submodule_attention_forward(
-        self,
-        hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
-        rotary_pos_cos: Optional[Tensor] = None,
-        rotary_pos_sin: Optional[Tensor] = None,
-        attention_bias: Optional[Tensor] = None,
-        inference_context: Optional[Any] = None,
-        packed_seq_params: Optional[PackedSeqParams] = None,
-        sequence_len_offset: Optional[Tensor] = None,
-        *,
-        inference_params: Optional[Any] = None,
-    ):
-        # todo
-        inference_context = deprecate_inference_params(inference_context, inference_params)
-
-        # Residual connection.
-        residual = hidden_states
-
-        # Optional Input Layer norm
-        if self.recompute_input_layernorm:
-            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
-                self.input_layernorm, hidden_states
-            )
-        else:
-            input_layernorm_output = self.input_layernorm(hidden_states)
-
-        # Self attention.
-        nvtx_range_push(suffix="self_attention")
-        attention_output_with_bias = self.self_attention(
-            input_layernorm_output,
-            attention_mask=attention_mask,
-            inference_context=inference_context,
-            rotary_pos_emb=rotary_pos_emb,
-            rotary_pos_cos=rotary_pos_cos,
-            rotary_pos_sin=rotary_pos_sin,
-            attention_bias=attention_bias,
-            packed_seq_params=packed_seq_params,
-            sequence_len_offset=sequence_len_offset,
-        )
-        nvtx_range_pop(suffix="self_attention")
-
-        if self.recompute_input_layernorm:
-            # discard the output of the input layernorm and register the recompute
-            # as a gradient hook of attention_output_with_bias[0]
-            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
-                attention_output_with_bias[0]
-            )
-
-        # TODO: could we move `bias_dropout_add_exec_handler` itself
-        # inside the module provided in the `bias_dropout_add_spec` module?
-        nvtx_range_push(suffix="self_attn_bda")
-        with self.bias_dropout_add_exec_handler():
-            hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
-                attention_output_with_bias, residual, self.hidden_dropout
-            )
-        nvtx_range_pop(suffix="self_attn_bda")
-
-        return hidden_states
+        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
+            pre_mlp_layernorm_output, probs, routing_map
+        )
+        return (tokens_per_expert, permutated_local_input_tokens, permuted_probs, pre_mlp_layernorm_output, residual, context)
+
+    def _submodule_attn_router_postprocess(
+        self,
+        node,
+        tokens_per_expert,
+        permutated_local_input_tokens,
+        permuted_probs,
+        pre_mlp_layernorm_output,
+        residual,
+        context,
+    ):
+        node.common_state.residual = node.detach(residual)
+        if self.mlp.use_shared_expert:
+            node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
+        return tokens_per_expert, permutated_local_input_tokens, permuted_probs
+
+    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs, state=None):
+        """
+        Dispatches tokens to the appropriate experts based on the router output.
+        """
+        token_dispatcher = self.mlp.token_dispatcher
+        tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
+        return tokens_per_expert, global_input_tokens, global_probs
+
+    def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
+        return tokens_per_expert, global_input_tokens, global_probs
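The postprocess step above stashes detached copies of `residual` and the pre-MLP layernorm output on `node.common_state`, which is what allows each submodule's backward to be scheduled independently. A minimal standalone sketch of that detach-and-resume-backward idea (the helper name is illustrative, not the ScheduleNode API):

```python
import torch

def detach_keep_grad(t):
    # Cut the autograd graph between two schedule phases while keeping a leaf
    # whose .grad can later be fed back into the upstream phase.
    detached = t.detach()
    detached.requires_grad_(t.requires_grad)
    return detached

x = torch.randn(4, 4, requires_grad=True)

# Phase 1 (e.g. attention + router): produces an activation handed downstream.
h1 = x * 2
h1_detached = detach_keep_grad(h1)

# Phase 2 (e.g. dispatch / MLP / combine): consumes the detached activation.
out = (h1_detached ** 2).sum()

# Backward runs phase by phase, which is what lets a combined-1F1B schedule
# interleave another microbatch's forward in between the two calls.
out.backward()                                  # phase 2 backward
torch.autograd.backward(h1, h1_detached.grad)   # phase 1 backward
print(x.grad is not None)  # True
```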
-    def _submodule_attention_router_compound_forward(
-        self,
-        hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
-        rotary_pos_cos: Optional[Tensor] = None,
-        rotary_pos_sin: Optional[Tensor] = None,
-        attention_bias: Optional[Tensor] = None,
-        inference_context: Optional[Any] = None,
-        packed_seq_params: Optional[PackedSeqParams] = None,
-        sequence_len_offset: Optional[Tensor] = None,
-        *,
-        inference_params: Optional[Any] = None,
-    ):
-        """
-        Performs a combined forward pass that includes self-attention and MLP routing logic.
-        """
-        hidden_states = self._submodule_attention_forward(
-            hidden_states,
-            attention_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_context,
-            packed_seq_params,
-            sequence_len_offset,
-            inference_params=inference_params,
-        )
-
-        # Optional Layer norm post the cross-attention.
-        if self.recompute_pre_mlp_layernorm:
-            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
-                self.pre_mlp_layernorm, hidden_states
-            )
-        else:
-            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
-
-        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
-        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
-            pre_mlp_layernorm_output, probs, routing_map
-        )
-        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.mlp.token_dispatcher.dispatch_preprocess(
-            pre_mlp_layernorm_output, routing_map, tokens_per_expert
-        )
-
-        outputs = [
-            hidden_states,
-            pre_mlp_layernorm_output,
-            tokens_per_expert,
-            permutated_local_input_tokens,
-            permuted_probs,
-        ]
-        return tuple(outputs)
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, state=None):
+        """
+        Performs a forward pass for the MLP submodule, including both expert-based
+        and optional shared-expert computations.
+        """
+        shared_expert_output = None
+        token_dispatcher = self.mlp.token_dispatcher
+
+        dispatched_input, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess(
+            tokens_per_expert, global_input_tokens, global_probs
+        )
+
+        expert_output, mlp_bias = self.mlp.experts(
+            dispatched_input, tokens_per_expert, permuted_probs
+        )
+        assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
+        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            assert state is not None
+            shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
+        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        return expert_output, shared_expert_output, mlp_bias
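`self.mlp.experts(dispatched_input, tokens_per_expert, ...)` relies on the tokens arriving already grouped per expert, with `tokens_per_expert` giving the slice sizes. A toy version of that data layout (the real grouped MLP uses fused grouped GEMMs and probability weighting; this only shows the slicing idea):

```python
import torch
import torch.nn as nn

class ToyGroupedExperts(nn.Module):
    """Stand-in for a grouped-experts call: inputs are pre-grouped per expert."""
    def __init__(self, num_experts, hidden):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

    def forward(self, dispatched_input, tokens_per_expert):
        # Slice the contiguous buffer into one chunk per expert, run each
        # expert on its own chunk, and concatenate in the same order.
        chunks = torch.split(dispatched_input, tokens_per_expert.tolist(), dim=0)
        outputs = [expert(chunk) for expert, chunk in zip(self.experts, chunks)]
        return torch.cat(outputs, dim=0)

experts = ToyGroupedExperts(num_experts=2, hidden=8)
dispatched = torch.randn(5, 8)                 # 3 tokens for expert 0, 2 for expert 1
out = experts(dispatched, torch.tensor([3, 2]))
```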
-    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
-        """
-        Dispatches tokens to the appropriate experts based on the router output.
-        """
-        tokens_per_expert, global_input_tokens, global_probs = self.mlp.token_dispatcher.dispatch_all_to_all(
-            tokens_per_expert, permutated_local_input_tokens, permuted_probs
-        )
-        return [tokens_per_expert, global_input_tokens, global_probs]
-
-    def _submodule_dense_forward(self, hidden_states):
-        residual = hidden_states
-        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
-        mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+    def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
+        assert mlp_bias is None
+        node.common_state.pre_mlp_layernorm_output = None
+        if shared_expert_output is None:
+            return expert_output
+        return expert_output, shared_expert_output
+
+    def _submodule_combine_forward(self, expert_output, shared_expert_output=None, state=None):
+        residual = state.residual
+        token_dispatcher = self.mlp.token_dispatcher
+        permutated_local_input_tokens = token_dispatcher.combine_all_to_all(expert_output)
+        output = token_dispatcher.combine_postprocess(permutated_local_input_tokens)
+        if shared_expert_output is not None:
+            output = output + shared_expert_output
+        mlp_output_with_bias = (output, None)
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout
@@ -413,71 +311,112 @@ class TransformerLayer(MegatronCoreTransformerLayer):
        return output
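The `mlp_bda` / `self_attn_bda` helpers used in the context lines above apply the fused bias-dropout-add-residual step. For reference, the underlying math is just the following (a plain-PyTorch sketch; Megatron ships jitted/fused variants, and the function name here is illustrative):

```python
import torch

def bias_dropout_add(x_with_bias, residual, p, training=True):
    # out = residual + dropout(x + bias), the same algebra the bda helpers fuse.
    x, bias = x_with_bias
    if bias is not None:
        x = x + bias
    return residual + torch.nn.functional.dropout(x, p=p, training=training)

residual = torch.randn(2, 4)
mlp_out = torch.randn(2, 4)
out = bias_dropout_add((mlp_out, None), residual, p=0.1, training=False)
```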
-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output):
-        """
-        Performs a forward pass for the MLP submodule, including both expert-based
-        and optional shared-expert computations.
-        """
-        shared_expert_output = None
-        (dispatched_input, tokens_per_expert, permuted_probs) = (
-            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
-        )
-        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
-        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
-        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            # if shared_expert_overlap is True, the expert calculation happens in
-            # the token_dispatcher to overlap communications and computations
-            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
-        return expert_output, shared_expert_output, mlp_bias
+    def _submodule_combine_postprocess(self, node, output):
+        cur_stream = torch.cuda.current_stream()
+        node.common_state.residual.record_stream(cur_stream)
+        node.common_state.residual = None
+        return output
+
+    def _submodule_attn_router_dw(self):
+        self.self_attention.backward_dw()
+
+    def _submodule_mlp_dw(self):
+        self.mlp.backward_dw()
+
+    def _submodule_attn_postprocess(self, node, pre_mlp_layernorm_output, residual, context):
+        return pre_mlp_layernorm_output, residual
+
+    def _submodule_dense_postprocess(self, node, hidden_states):
+        return hidden_states
+
+    def _submodule_not_implemented(self, *args):
+        raise NotImplementedError("This callable is not implemented.")
+
+    def get_submodule_callables(self, chunk_state):
+        """
+        Each forward callable takes two kinds of inputs:
+        1. The ScheduleNode object.
+        2. The input tensors.
+        """
+        from megatron.core.transformer.moe.moe_layer import MoELayer
+        from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher
+
+        self.is_moe = isinstance(self.mlp, MoELayer)
+        self.is_deepep = False
+        if self.is_moe:
+            self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)
+
+        def get_func_with_default(func, default_func):
+            if self.is_moe:
+                return func
+            return default_func
+
+        def callable_wrapper(forward_func, postprocess_func, node, *args):
+            state = getattr(node, 'common_state', None)
+            callable_outputs = forward_func(*args, state=state)
+            if isinstance(callable_outputs, tuple):
+                outputs = postprocess_func(node, *callable_outputs)
+            else:
+                outputs = postprocess_func(node, callable_outputs)
+            return outputs
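`callable_wrapper` binds a compute step to its postprocess step so every schedule node sees a uniform `(node, *inputs)` interface, and the `partial(...)` calls below fix its first two arguments. A tiny self-contained version of that wiring (toy functions, not the real submodules):

```python
from functools import partial

class Node:
    """Stand-in for a ScheduleNode: it only carries shared per-microbatch state."""
    def __init__(self):
        self.common_state = {}

def callable_wrapper(forward_func, postprocess_func, node, *args):
    # Run the compute step, then hand its outputs (plus the node) to the
    # postprocess step that stashes or frees shared state.
    state = getattr(node, "common_state", None)
    outputs = forward_func(*args, state=state)
    if isinstance(outputs, tuple):
        return postprocess_func(node, *outputs)
    return postprocess_func(node, outputs)

def toy_forward(x, state=None):
    return x * x, x + 1

def toy_postprocess(node, squared, incremented):
    node.common_state["incremented"] = incremented
    return squared

toy_step = partial(callable_wrapper, toy_forward, toy_postprocess)
node = Node()
print(toy_step(node, 3))   # 9
print(node.common_state)   # {'incremented': 4}
```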
-    def _submodule_combine_forward(self, hidden_states):
-        return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
-
-    def _submodule_post_combine_forward(
-        self, expert_output, shared_expert_output, mlp_bias, residual
-    ):
-        """
-        Re-combines the expert outputs (and optional shared_expert_output) into the same order
-        as the original input tokens, applying any required bias.
-        """
-        output = self.mlp.token_dispatcher.combine_postprocess(expert_output)
-        if shared_expert_output is not None:
-            output += shared_expert_output
-        mlp_output_with_bias = (output, mlp_bias)
-
-        if self.recompute_pre_mlp_layernorm:
-            # discard the output of the pre-mlp layernorm and register the recompute
-            # as a gradient hook of mlp_output_with_bias[0]
-            self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
-                mlp_output_with_bias[0]
-            )
-
-        # TODO: could we move `bias_dropout_add_exec_handler` itself
-        # inside the module provided in the `bias_dropout_add_spec` module?
-        nvtx_range_push(suffix="mlp_bda")
-        with self.bias_dropout_add_exec_handler():
-            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
-                mlp_output_with_bias, residual, self.hidden_dropout
-            )
-        nvtx_range_pop(suffix="mlp_bda")
-
-        # Jit compiled function creates 'view' tensor. This tensor
-        # potentially gets saved in the MPU checkpoint function context,
-        # which rejects view tensors. While making a viewless tensor here
-        # won't result in memory savings (like the data loader, or
-        # p2p_communication), it serves to document the origin of this
-        # 'view' tensor.
-        output = make_viewless_tensor(
-            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
-        )
-        return output
-
-    def _submodule_attention_dw(self):
-        self.self_attention.backward_dw()
-
-    def _submodule_attention_router_compound_dw(self):
-        self._submodule_attention_dw()
-
-    def _submodule_mlp_dw(self):
-        self.mlp.backward_dw()
+        attn_func = get_func_with_default(
+            self._submodule_attn_router_forward, self._forward_attention
+        )
+
+        def attn_wrapper(hidden_states, state=None):
+            """
+            state (Any, optional): Unused placeholder so the signature matches callable_wrapper.
+            """
+            return attn_func(
+                hidden_states=hidden_states,
+                attention_mask=chunk_state.attention_mask,
+                context=chunk_state.context,
+                context_mask=chunk_state.context_mask,
+                rotary_pos_emb=chunk_state.rotary_pos_emb,
+                rotary_pos_cos=chunk_state.rotary_pos_cos,
+                rotary_pos_sin=chunk_state.rotary_pos_sin,
+                attention_bias=chunk_state.attention_bias,
+                inference_context=chunk_state.inference_context,
+                packed_seq_params=chunk_state.packed_seq_params,
+                sequence_len_offset=chunk_state.sequence_len_offset,
+                inference_params=chunk_state.inference_params,
+            )

+        attn_postprocess_func = get_func_with_default(
+            self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
+        )
+        dispatch_func = get_func_with_default(
+            self._submodule_dispatch_forward, self._submodule_not_implemented
+        )
+        dispatch_postprocess_func = get_func_with_default(
+            self._submodule_dispatch_postprocess, self._submodule_not_implemented
+        )
+        mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
+        mlp_postprocess_func = get_func_with_default(
+            self._submodule_mlp_postprocess, self._submodule_dense_postprocess
+        )
+        combine_func = get_func_with_default(
+            self._submodule_combine_forward, self._submodule_not_implemented
+        )
+        combine_postprocess_func = get_func_with_default(
+            self._submodule_combine_postprocess, self._submodule_not_implemented
+        )
+
+        attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
+        dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
+        mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
+        combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)
+
+        callables = TransformerLayerSubmoduleCallables(
+            attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
+            dispatch=SubmoduleCallables(forward=dispatch_forward),
+            mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
+            combine=SubmoduleCallables(forward=combine_forward),
+            is_moe=self.is_moe,
+            is_deepep=self.is_deepep,
+        )
+        return callables
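For orientation, the four callables are meant to be invoked in attention → dispatch → mlp → combine order for each microbatch; the scheduler may interleave phases of different microbatches between them. A hypothetical, self-contained mirror of that flow (the dataclasses and stub phases below are illustrative, not the real SubmoduleCallables containers):

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class SubmoduleCallables:          # hypothetical mirror of the real container
    forward: Callable
    dw: Optional[Callable] = None

@dataclass
class LayerCallables:              # attention -> dispatch -> mlp -> combine
    attention: SubmoduleCallables
    dispatch: SubmoduleCallables
    mlp: SubmoduleCallables
    combine: SubmoduleCallables

class Node:                        # stand-in for a ScheduleNode
    pass

# Stub phases: each takes (node, *inputs), just like the wrapped callables above.
phases = LayerCallables(
    attention=SubmoduleCallables(forward=lambda node, x: x + 1),
    dispatch=SubmoduleCallables(forward=lambda node, x: x * 2),
    mlp=SubmoduleCallables(forward=lambda node, x: x - 3),
    combine=SubmoduleCallables(forward=lambda node, x: x),
)

def run_layer(node, c, x):
    # Run back to back this is an ordinary layer forward; the combined-1F1B
    # scheduler instead interleaves phases from different microbatches.
    for phase in (c.attention, c.dispatch, c.mlp, c.combine):
        x = phase.forward(node, x)
    return x

print(run_layer(Node(), phases, 10))  # 19
```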