Commit 7c9dc3ec authored by dongcl

forward_backward_pipelining_without_interleaving supports a2a_overlap

parent 649bfbdb
@@ -5,6 +5,8 @@ import types
 import argparse
 import torch
+from megatron.core.utils import is_te_min_version

 class MegatronAdaptation:
     """
@@ -89,14 +91,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass

     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel

         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper,
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
-                                    gpt_model_forward)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel',
+                                    GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
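
The `MegatronAdaptation.register` calls above swap Megatron-Core symbols by dotted path. The adaptation machinery itself is not part of this diff, so the snippet below is only a minimal sketch of the general idea (resolve the path, then rebind the attribute, optionally treating the replacement as a wrapper around the original); `simple_register` is a hypothetical name, not this repository's API.

```python
import importlib

def simple_register(dotted_path, replacement, apply_wrapper=False):
    # Hypothetical sketch of register-by-dotted-path monkey patching; not the
    # actual MegatronAdaptation implementation.
    parts = dotted_path.split('.')
    # Find the longest importable module prefix; the rest is an attribute chain.
    for split in range(len(parts) - 1, 0, -1):
        try:
            holder = importlib.import_module('.'.join(parts[:split]))
        except ModuleNotFoundError:
            continue
        attrs = parts[split:]
        break
    else:
        raise ImportError(f"cannot import any prefix of {dotted_path}")
    # Walk to the object (module or class) that owns the final attribute.
    for name in attrs[:-1]:
        holder = getattr(holder, name)
    original = getattr(holder, attrs[-1])
    # With apply_wrapper=True the replacement is a factory (e.g. a decorator or
    # torch.compile(...)) that receives the original callable.
    setattr(holder, attrs[-1], replacement(original) if apply_wrapper else replacement)
```
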
@@ -116,9 +118,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
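
Registering `torch.compile(options=...)` with `apply_wrapper=True` keeps the MoE utility's name but routes it through the inductor backend with the given Triton/CUDA-graph options. The standalone example below shows that decorator-factory usage with a toy routing function (not the real `topk_softmax_with_capacity`):

```python
import torch

# torch.compile(...) called without a function returns a decorator; applying it
# to an existing callable is what register(..., apply_wrapper=True) does above.
compile_with_cudagraphs = torch.compile(
    options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}
)

@compile_with_cudagraphs
def toy_topk_softmax(logits: torch.Tensor, k: int = 2):
    # Toy stand-in for an MoE routing utility: softmax over experts, then top-k.
    probs = torch.softmax(logits, dim=-1)
    return torch.topk(probs, k, dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
values, indices = toy_topk_softmax(torch.randn(4, 8, device=device))
print(values.shape)  # torch.Size([4, 2])
```
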
@@ -132,12 +134,25 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
         from megatron.core.extensions.transformer_engine import TEGroupedLinear

-        # kv channels, te_min_version 1.10.0 -> 1.9.0
-        MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
-                                    TEDotProductAttentionPatch.__init__)
+        if not is_te_min_version("1.10.0"):
+            # kv channels, te_min_version 1.10.0 -> 1.9.0
+            MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
+                                        TEDotProductAttentionPatch.__init__)

         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
-            TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
+            TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,)
+
+    def patch_pipeline_parallel(self):
+        from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
+
+        # num_warmup_microbatches + 1
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
+                                    get_pp_rank_microbatches)
+
+        # a2a_overlap
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+                                    forward_backward_pipelining_with_interleaving)

     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
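
Two version-sensitive changes sit in this hunk: the kv-channels patch is now applied only when Transformer Engine is older than 1.10.0, and the `GROUPED_GEMM_BatchLinear` branch retargets `TEGroupedLinear.__bases__` to `BatchedLinear` on TE >= 2.3.0.dev0 (the new `patch_pipeline_parallel` method below it registers the a2a-overlap schedule in the same register style). Reassigning `__bases__` is plain Python; a self-contained toy of the mechanism (classes here are illustrative, not Transformer Engine's):

```python
class LegacyGroupedGemmBase:
    def forward(self):
        return "legacy grouped GEMM path"

class BatchedLinearBase:
    def forward(self):
        return "batched linear path"

class ToyGroupedLinear(LegacyGroupedGemmBase):
    pass

# Swapping __bases__ rewires method resolution for existing and future instances
# without editing the class body; the patch gates this on an env var plus a
# Transformer Engine version check.
use_batched_backend = True
if use_batched_backend:
    ToyGroupedLinear.__bases__ = (BatchedLinearBase,)

print(ToyGroupedLinear().forward())  # -> "batched linear path"
```
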
...@@ -162,7 +177,7 @@ class CoreAdaptation(MegatronAdaptationABC): ...@@ -162,7 +177,7 @@ class CoreAdaptation(MegatronAdaptationABC):
# flux # flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")): if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel import ( from ..core.tensor_parallel.layers import (
FluxColumnParallelLinear, FluxColumnParallelLinear,
FluxRowParallelLinear FluxRowParallelLinear
) )
...
@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.torch_norm import L2Norm
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear
 )
-from dcu_megatron.core.transformer.multi_token_prediction import (
-    MultiTokenPredictionBlockSubmodules,
-    get_mtp_layer_offset,
-    get_mtp_layer_spec,
-    get_mtp_num_layers_to_build,
-)

 def get_gpt_layer_with_flux_spec(
...@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec( ...@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
multi_latent_attention: Optional[bool] = False, multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False,
qk_l2_norm: Optional[bool] = False,
) -> ModuleSpec: ) -> ModuleSpec:
"""Use this spec to use flux modules (required for fp8 training). """Use this spec to use flux modules (required for fp8 training).
...@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec( ...@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
fp8 (str, optional): Deprecated. For temporary Nemo compatibility. fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False. Defaults to False.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
Returns: Returns:
ModuleSpec: Module specification with flux modules ModuleSpec: Module specification with flux modules
...@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec( ...@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
) )
if multi_latent_attention: if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
return ModuleSpec( return ModuleSpec(
module=TransformerLayer, module=TransformerLayer,
submodules=TransformerLayerSubmodules( submodules=TransformerLayerSubmodules(
...@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec( ...@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
linear_qkv=FluxColumnParallelLinear, linear_qkv=FluxColumnParallelLinear,
core_attention=TEDotProductAttention, core_attention=TEDotProductAttention,
linear_proj=FluxRowParallelLinear, linear_proj=FluxRowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp, q_layernorm=(
k_layernorm=qk_norm if qk_layernorm else IdentityOp, L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
k_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
), ),
), ),
self_attn_bda=get_bias_dropout_add, self_attn_bda=get_bias_dropout_add,
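
The new `q_layernorm`/`k_layernorm` expressions resolve with a fixed precedence: `qk_l2_norm` wins, then `qk_layernorm`, otherwise `IdentityOp`. A tiny standalone check of that precedence (placeholder classes; the real `L2Norm`, `qk_norm`, and `IdentityOp` come from Megatron-Core):

```python
class L2Norm: ...
class QKNorm: ...      # stands in for the spec's qk_norm (e.g. TENorm)
class IdentityOp: ...

def pick_qk_norm(qk_l2_norm: bool, qk_layernorm: bool):
    # Mirrors the nested conditional used for q_layernorm / k_layernorm above.
    return L2Norm if qk_l2_norm else (QKNorm if qk_layernorm else IdentityOp)

assert pick_qk_norm(True, True) is L2Norm        # l2 norm takes priority
assert pick_qk_norm(False, True) is QKNorm       # plain qk layernorm
assert pick_qk_norm(False, False) is IdentityOp  # no normalization
```
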
...
@@ -13,8 +13,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
-from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear

 def gpt_model_init_wrapper(fn):
     @wraps(fn)
...@@ -22,12 +20,13 @@ def gpt_model_init_wrapper(fn): ...@@ -22,12 +20,13 @@ def gpt_model_init_wrapper(fn):
fn(self, *args, **kwargs) fn(self, *args, **kwargs)
# Output # Output
if self.post_process or self.mtp_process: if (
if int(os.getenv("USE_FLUX_OVERLAP", "0")): (self.post_process or self.mtp_process)
parallel_linear_impl = FluxColumnParallelLinear and int(os.getenv("USE_FLUX_OVERLAP", "0"))
else: ):
parallel_linear_impl = tensor_parallel.ColumnParallelLinear from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear
self.output_layer = parallel_linear_impl(
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size, self.config.hidden_size,
self.vocab_size, self.vocab_size,
config=self.config, config=self.config,
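
`gpt_model_init_wrapper` follows the usual post-hook pattern: run the stock `__init__`, then conditionally rebuild the output layer when `USE_FLUX_OVERLAP` is set. A generic sketch of that pattern with toy classes (only the env-var gate is taken from the patch; everything else is illustrative):

```python
import os
from functools import wraps

class _PlainLinear:
    def __init__(self, size):
        self.size = size

class _FluxLinear(_PlainLinear):
    """Stand-in for an overlap-enabled output projection."""

def output_layer_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)  # run the original __init__ untouched
        # Post-hook: replace the output layer only when the env flag is set,
        # mirroring the USE_FLUX_OVERLAP gate in the patched wrapper above.
        if getattr(self, "post_process", False) and int(os.getenv("USE_FLUX_OVERLAP", "0")):
            self.output_layer = _FluxLinear(self.hidden_size)
    return wrapper

class ToyGPTModel:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size
        self.post_process = True
        self.output_layer = _PlainLinear(hidden_size)

ToyGPTModel.__init__ = output_layer_init_wrapper(ToyGPTModel.__init__)
model = ToyGPTModel(1024)
print(type(model.output_layer).__name__)  # _FluxLinear when USE_FLUX_OVERLAP=1
```
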
...@@ -41,8 +40,8 @@ def gpt_model_init_wrapper(fn): ...@@ -41,8 +40,8 @@ def gpt_model_init_wrapper(fn):
grad_output_buffer=self.grad_output_buffer, grad_output_buffer=self.grad_output_buffer,
) )
if self.pre_process or self.post_process: if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer() self.setup_embeddings_and_output_layer()
return wrapper return wrapper
...
+from .layers import (
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear,
+)
\ No newline at end of file
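
This new `__init__.py` re-exports the Flux linears at package level. Assuming the `dcu_megatron` package layout implied by this diff, both import forms below resolve to the same class, which is why other hunks in this commit can switch between the package path and the explicit `.layers` module path:

```python
# Package-level re-export introduced by the new __init__.py ...
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear as ViaPackage
# ... and the direct module import now used at the call sites in this commit.
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear as ViaModule

assert ViaPackage is ViaModule  # one class, two import paths
```
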
...@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear): ...@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert: bool = False, is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False, disable_grad_reduce: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
): ):
super(FluxColumnParallelLinear, self).__init__( super(FluxColumnParallelLinear, self).__init__(
input_size=input_size, input_size=input_size,
...@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear): ...@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert=is_expert, is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name, tp_comm_buffer_name=tp_comm_buffer_name,
disable_grad_reduce=disable_grad_reduce, disable_grad_reduce=disable_grad_reduce,
tp_group=tp_group,
) )
# flux params # flux params
...@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear): ...@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
keep_master_weight_for_test: bool = False, keep_master_weight_for_test: bool = False,
is_expert: bool = False, is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used tp_comm_buffer_name: str = None, # Not used
tp_group: Optional[torch.distributed.ProcessGroup] = None,
): ):
super(FluxRowParallelLinear, self).__init__( super(FluxRowParallelLinear, self).__init__(
...@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear): ...@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
stride=stride, stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test, keep_master_weight_for_test=keep_master_weight_for_test,
is_expert=is_expert, is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
) )
# flux params # flux params
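
Both Flux linears now accept an optional `tp_group` and forward it to the Megatron-Core parent constructors, so callers can pass an explicit tensor-parallel process group instead of relying on the globally initialized one. A minimal sketch of threading such a parameter through a layer (toy module using plain `torch.distributed`; not the Flux implementation):

```python
from typing import Optional

import torch
import torch.distributed as dist

class ToyRowParallelLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 tp_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        # Keep the explicit TP group if given; None means "use the default group",
        # which is the usual meaning of an optional tp_group kwarg.
        self.tp_group = tp_group
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        if dist.is_initialized():
            # Sum partial outputs across the tensor-parallel ranks of this group.
            dist.all_reduce(y, group=self.tp_group)
        return y
```
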
...
...@@ -23,6 +23,7 @@ def transformer_config_post_init_wrapper(fn): ...@@ -23,6 +23,7 @@ def transformer_config_post_init_wrapper(fn):
################## ##################
self.flux_transpose_weight = args.flux_transpose_weight self.flux_transpose_weight = args.flux_transpose_weight
return wrapper return wrapper
...@@ -33,6 +34,12 @@ class ExtraTransformerConfig: ...@@ -33,6 +34,12 @@ class ExtraTransformerConfig:
################## ##################
flux_transpose_weight: bool = False flux_transpose_weight: bool = False
combined_1f1b: bool = False
"""If true, use combined 1F1B for communication hiding."""
combined_1f1b_recipe: str = 'ep_a2a'
"""Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
@dataclass @dataclass
class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig): class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
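
`ExtraTransformerConfig` carries the DCU-specific fields and `TransformerConfigPatch` mixes them into Megatron's `TransformerConfig` through dataclass multiple inheritance. A self-contained illustration of that pattern with stand-in classes:

```python
from dataclasses import dataclass

@dataclass
class BaseConfig:                  # stand-in for megatron's TransformerConfig
    hidden_size: int = 1024
    num_layers: int = 24

@dataclass
class ExtraConfig:                 # stand-in for ExtraTransformerConfig
    combined_1f1b: bool = False
    combined_1f1b_recipe: str = 'ep_a2a'

@dataclass
class PatchedConfig(BaseConfig, ExtraConfig):
    # The MRO is PatchedConfig -> BaseConfig -> ExtraConfig, so the patched
    # class exposes both field sets while remaining a drop-in BaseConfig.
    pass

cfg = PatchedConfig(hidden_size=2048, combined_1f1b=True)
print(cfg.combined_1f1b_recipe)    # 'ep_a2a'
```
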
...
...@@ -26,6 +26,8 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser): ...@@ -26,6 +26,8 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
parser = _add_extra_training_args(parser) parser = _add_extra_training_args(parser)
parser = _add_extra_distributed_args(parser) parser = _add_extra_distributed_args(parser)
parser = _add_extra_tokenizer_args(parser) parser = _add_extra_tokenizer_args(parser)
parser = _add_extra_moe_args(parser)
parser = _add_flux_args(parser)
return parser return parser
...@@ -128,6 +130,18 @@ def _add_extra_tokenizer_args(parser): ...@@ -128,6 +130,18 @@ def _add_extra_tokenizer_args(parser):
return parser return parser
def _add_extra_moe_args(parser):
group = parser.add_argument_group(title="extra moe args")
group.add_argument('--combined-1f1b', action='store_true',
help='Batch-level overlapping in 1f1b stage.')
group.add_argument('--combined-1f1b-recipe', type=str,
choices=['ep_a2a', 'golden'],
default='golden',
help='Options are "ep_a2a" and "golden".')
return parser
def _add_flux_args(parser): def _add_flux_args(parser):
group = parser.add_argument_group(title='flux args') group = parser.add_argument_group(title='flux args')
group.add_argument('--flux-transpose-weight', action='store_true', default=False, group.add_argument('--flux-transpose-weight', action='store_true', default=False,
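
The new argument group follows the existing Megatron argparse style; dash-separated flags surface as underscore attributes on the parsed namespace, which is how `args.combined_1f1b` can reach the config side. Note that the CLI default here is 'golden' while the `ExtraTransformerConfig` dataclass added above defaults to 'ep_a2a'. A standalone check of just the added arguments:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="extra moe args")
group.add_argument('--combined-1f1b', action='store_true',
                   help='Batch-level overlapping in 1f1b stage.')
group.add_argument('--combined-1f1b-recipe', type=str,
                   choices=['ep_a2a', 'golden'], default='golden',
                   help='Options are "ep_a2a" and "golden".')

# argparse exposes '--combined-1f1b-recipe' as args.combined_1f1b_recipe.
args = parser.parse_args(['--combined-1f1b', '--combined-1f1b-recipe', 'ep_a2a'])
assert args.combined_1f1b is True and args.combined_1f1b_recipe == 'ep_a2a'
```
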
...