Commit 4bb958ec authored by dongcl

support a2a_overlap

parent 7c9dc3ec
from megatron.core.utils import is_te_min_version
def a2a_overlap_adaptation(patches_manager):
"""
patches_manager: MegatronPatchesManager
"""
from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from ..core.transformer.transformer_block import TransformerBlock
from ..core.transformer.transformer_layer import TransformerLayer
from ..core.models.gpt.gpt_model import GPTModel
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
from ..core.extensions.transformer_engine import _get_extra_te_kwargs_wrapper, TELinear, TELayerNormColumnParallelLinear
from ..core.transformer.multi_latent_attention import MLASelfAttention
from ..core.transformer.mlp import MLP
from ..core.transformer.moe.experts import TEGroupedMLP
from ..core.transformer.moe.moe_layer import MoELayer
# num_warmup_microbatches + 1
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
MoEAlltoAllTokenDispatcher)
patches_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock',
TransformerBlock)
patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
TransformerLayer)
patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel',
GPTModel)
# backward_dw
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
TELayerNormColumnParallelLinear)
if is_te_min_version("1.9.0.dev0"):
from ..core.extensions.transformer_engine import TEGroupedLinear
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
TEGroupedLinear)
patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention',
MLASelfAttention)
patches_manager.register_patch('megatron.core.transformer.mlp.MLP',
MLP)
patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP',
TEGroupedMLP)
patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer',
MoELayer)
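# The registrations above assume a MegatronPatchesManager exposing classmethod-style
# register_patch(orig_path, replacement, apply_wrapper=False) and apply_patches().
# The real manager lives in patch_utils and is not shown in this diff; the sketch
# below only illustrates that assumed contract (all names here are illustrative).
import importlib


class _SketchPatchesManager:
    _registry = []

    @classmethod
    def register_patch(cls, orig_path, replacement, apply_wrapper=False):
        # Remember what to patch; nothing is rebound until apply_patches().
        cls._registry.append((orig_path, replacement, apply_wrapper))

    @classmethod
    def apply_patches(cls):
        for orig_path, replacement, apply_wrapper in cls._registry:
            module_path, attr = orig_path.rsplit('.', 1)
            module = importlib.import_module(module_path)
            # apply_wrapper=True means `replacement` wraps the original callable
            # (e.g. _get_extra_te_kwargs_wrapper); otherwise it replaces it outright.
            new_obj = replacement(getattr(module, attr)) if apply_wrapper else replacement
            setattr(module, attr, new_obj)
            # Class-attribute targets such as 'GPTModel.forward' would need an extra
            # getattr hop before setattr; that case is omitted from this sketch.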
@@ -24,6 +24,13 @@ class MegatronAdaptation:
adaptation.execute()
MegatronAdaptation.apply()
# apply features
from .patch_utils import MegatronPatchesManager
from .features_manager import a2a_overlap_adaptation
a2a_overlap_adaptation(MegatronPatchesManager)
MegatronPatchesManager.apply_patches()
@classmethod
def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False, apply_wrapper=False, remove_origin_wrappers=False):
"""
@@ -91,14 +98,14 @@ class CoreAdaptation(MegatronAdaptationABC):
pass
def patch_core_models(self):
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel',
GPTModel)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
@@ -142,18 +149,6 @@ class CoreAdaptation(MegatronAdaptationABC):
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
# num_warmup_microbatches + 1
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
@@ -190,9 +185,6 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
pass
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
......
@@ -3,7 +3,7 @@ import torch
import dataclasses
import transformer_engine as te
from typing import Any, Optional
from typing import Any, Optional, Callable
from packaging.version import Version as PkgVersion
from megatron.core.packed_seq_params import PackedSeqParams
@@ -13,6 +13,9 @@ from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
@@ -22,6 +25,112 @@ from megatron.core.parallel_state import (
)
def _get_extra_te_kwargs_wrapper(fn):
@wraps(fn)
def wrapper(config: TransformerConfig):
extra_transformer_engine_kwargs = fn(config)
extra_transformer_engine_kwargs["delay_wgrad_compute"] = getattr(config, "split_bw", False)
return extra_transformer_engine_kwargs
return wrapper
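# Hypothetical self-check of the wrapper above: the returned kwargs dict gains a
# `delay_wgrad_compute` entry mirroring config.split_bw. SimpleNamespace stands in
# for TransformerConfig and `base_kwargs` for megatron's _get_extra_te_kwargs;
# nothing in this block is part of the patch itself.
def _example_delay_wgrad_kwargs():
    from types import SimpleNamespace

    def base_kwargs(config):
        return {"params_dtype": torch.float32}

    wrapped = _get_extra_te_kwargs_wrapper(base_kwargs)
    kwargs = wrapped(SimpleNamespace(split_bw=True))
    assert kwargs["delay_wgrad_compute"] is True
    return kwargs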
class TELinear(MegatronCoreTELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
parallel_mode currently supports 3 different values:
- "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
- "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
- "duplicated": No tensor parallelism and weight is duplicated across TP ranks
- Note: For expert linear layers, we will disable communication logic here
as TP communication is handled in token_dispatcher.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = getattr(config, "split_bw", False)
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
input_size,
output_size,
parallel_mode=parallel_mode,
config=config,
init_method=init_method,
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
is_expert=is_expert,
tp_group=tp_group,
)
def backward_dw(self):
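# split_bw is asserted off in __init__, so this is a no-op for now; once delayed
# weight-gradient computation (delay_wgrad_compute) is supported, this hook is
# where the deferred wgrad work for the layer would be triggered.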
if not self.split_bw:
return
class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = getattr(config, "split_bw", False)
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
input_size,
output_size,
config=config,
init_method=init_method,
gather_output=gather_output,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
if not self.split_bw:
return
class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
def __init__(
self,
@@ -176,3 +285,52 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
layer_number=layer_number,
**extra_kwargs,
)
if is_te_min_version("1.9.0.dev0"):
from megatron.core.extensions.transformer_engine import TEGroupedLinear as MegatronCoreTEGroupedLinear
class TEGroupedLinear(MegatronCoreTEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.split_bw = getattr(config, "split_bw", False)
assert not self.split_bw, "split_bw is currently not supported"
super().__init__(
num_gemms,
input_size,
output_size,
parallel_mode=parallel_mode,
config=config,
init_method=init_method,
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
if not self.split_bw:
return
@@ -239,7 +239,6 @@ class PostProcessNode(ScheduleNode):
return loss
class TransformerLayerNode(ScheduleNode):
def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
@@ -598,8 +597,6 @@ def schedule_layer_1f1b(
with f_context:
f_input = f_layer.mlp.forward(f_input)
def next_iter_pre_forward():
if f_layer is not None:
with f_context:
......
@@ -46,18 +46,7 @@ def gpt_model_init_wrapper(fn):
return wrapper
class GPTModel(MegatronCoreGPTModel):
"""
patch megatron GPTModel
"""
def get_transformer_callables_by_layer(self, layer_number: int):
"""
Get the callables for the layer at the given transformer layer number.
"""
return self.decoder.get_layer_callables(layer_number)
def build_schedule_plan(
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
@@ -71,66 +60,7 @@ class GPTModel(MegatronCoreGPTModel):
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
):
"""Builds a computation schedule plan for the model.
This function creates a schedule plan for a model chunk, including
preprocessing, transformer layers, and postprocessing.
The schedule plan is used to optimize computation and memory usage
in distributed environments.
Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Position IDs.
attention_mask (Tensor): Attention mask.
decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
labels (Tensor, optional): Labels for loss computation. Defaults to None.
inference_params (InferenceParams, optional):
Parameters for inference. Defaults to None.
packed_seq_params (PackedSeqParams, optional):
Parameters for packed sequences. Defaults to None.
extra_block_kwargs (dict, optional):
Additional keyword arguments for blocks. Defaults to None.
runtime_gather_output (Optional[bool], optional):
Whether to gather output at runtime. Defaults to None.
loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
Returns:
ModelChunkSchedulePlan: The model chunk schedule plan.
"""
from .fine_grained_schedule import build_model_chunk_schedule_plan
return build_model_chunk_schedule_plan(
self,
input_ids,
position_ids,
attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
extra_block_kwargs=extra_block_kwargs,
runtime_gather_output=runtime_gather_output,
inference_params=inference_params,
loss_mask=loss_mask,
)
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
@@ -300,3 +230,74 @@ class GPTModel(MegatronCoreGPTModel):
loss = self.compute_language_model_loss(labels, logits)
return loss
class GPTModel(MegatronCoreGPTModel):
"""
Patched Megatron GPTModel that adds schedule-plan helpers for a2a overlap.
"""
def get_transformer_callables_by_layer(self, layer_number: int):
"""
Get the callables for the layer at the given transformer layer number.
"""
return self.decoder.get_layer_callables(layer_number)
def build_schedule_plan(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
):
"""Builds a computation schedule plan for the model.
This function creates a schedule plan for a model chunk, including
preprocessing, transformer layers, and postprocessing.
The schedule plan is used to optimize computation and memory usage
in distributed environments.
Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Position IDs.
attention_mask (Tensor): Attention mask.
decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
labels (Tensor, optional): Labels for loss computation. Defaults to None.
inference_params (InferenceParams, optional):
Parameters for inference. Defaults to None.
packed_seq_params (PackedSeqParams, optional):
Parameters for packed sequences. Defaults to None.
extra_block_kwargs (dict, optional):
Additional keyword arguments for blocks. Defaults to None.
runtime_gather_output (Optional[bool], optional):
Whether to gather output at runtime. Defaults to None.
loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
Returns:
ModelChunkSchedulePlan: The model chunk schedule plan.
"""
from .fine_grained_schedule import build_model_chunk_schedule_plan
return build_model_chunk_schedule_plan(
self,
input_ids,
position_ids,
attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
extra_block_kwargs=extra_block_kwargs,
runtime_gather_output=runtime_gather_output,
inference_params=inference_params,
loss_mask=loss_mask,
)
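# Hedged usage note (inferred, not shown in this diff): with a2a overlap enabled,
# the patched interleaved pipeline schedule is expected to build a per-chunk plan,
# e.g.
#     plan = model.build_schedule_plan(input_ids, position_ids, attention_mask,
#                                      labels=labels, loss_mask=loss_mask)
# and execute it node by node, while gpt_model_forward (registered as
# GPTModel.forward) keeps serving the ordinary non-overlapped path.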
@@ -503,6 +503,7 @@ def get_default_cls_for_unwrap():
pass
return cls
def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
"""unwrap_model DistributedDataParallel and Float16Module wrapped model"""
return_list = True
......
from megatron.core.transformer.mlp import MLP as MegatronCoreMLP
class MLP(MegatronCoreMLP):
def backward_dw(self):
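# Release the delayed wgrad in reverse of the forward order (fc2 before fc1),
# matching the order in which activation gradients become available in backward.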
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
\ No newline at end of file
from megatron.core.transformer.moe.experts import TEGroupedMLP as MegatronCoreTEGroupedMLP
class TEGroupedMLP(MegatronCoreTEGroupedMLP):
def backward_dw(self):
self.linear_fc2.backward_dw()
self.linear_fc1.backward_dw()
from megatron.core.transformer.moe.moe_layer import MoELayer as MegatronCoreMoELayer
class MoELayer(MegatronCoreMoELayer):
def backward_dw(self):
self.experts.backward_dw()
# shared_experts is None unless a shared expert is configured, so guard the call
if self.shared_experts is not None:
self.shared_experts.backward_dw()
from megatron.core.transformer.multi_latent_attention import MLASelfAttention as MegatronCoreMLASelfAttention
class MLASelfAttention(MegatronCoreMLASelfAttention):
"""MLA Self-attention layer class
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def backward_dw(self):
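# Trigger the deferred weight-gradient computation for every projection owned by
# this attention block; which Q projections exist depends on whether q_lora_rank
# is configured (low-rank down/up pair vs. a single linear_q_proj).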
self.linear_kv_up_proj.backward_dw()
self.linear_kv_down_proj.backward_dw()
if self.config.q_lora_rank is None:
self.linear_q_proj.backward_dw()
else:
self.linear_q_down_proj.backward_dw()
self.linear_q_up_proj.backward_dw()
self.linear_proj.backward_dw()
@@ -17,16 +17,6 @@ def transformer_block_init_wrapper(fn):
class TransformerBlock(MegatronCoreTransformerBlock):
def __init__(
self, *args, **kwargs
):
super().__init__(*args, **kwargs)
# MTP requires separate layernorms for the main model and the MTP modules, so move the final layernorm out of the block
config = args[0] if len(args) > 0 else kwargs['config']
if getattr(config, "mtp_num_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
def get_layer_callables(self, layer_number: int):
"""
......