Commit 7c9dc3ec authored by dongcl

forward_backward_pipelining_without_interleaving supports a2a_overlap

parent 649bfbdb
......@@ -5,6 +5,8 @@ import types
import argparse
import torch
from megatron.core.utils import is_te_min_version
class MegatronAdaptation:
"""
......@@ -89,14 +91,14 @@ class CoreAdaptation(MegatronAdaptationABC):
pass
def patch_core_models(self):
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel',
GPTModel)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
......@@ -116,9 +118,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
apply_wrapper=True)
# MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
# torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
# apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
......@@ -132,12 +134,25 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
from megatron.core.extensions.transformer_engine import TEGroupedLinear
# kv channels, te_min_version 1.10.0 -> 1.9.0
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
TEDotProductAttentionPatch.__init__)
if not is_te_min_version("1.10.0"):
# kv channels, te_min_version 1.10.0 -> 1.9.0
MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
TEDotProductAttentionPatch.__init__)
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
# num_warmup_microbatches + 1
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
......@@ -162,7 +177,7 @@ class CoreAdaptation(MegatronAdaptationABC):
# flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel import (
from ..core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
......
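
Note on the registration pattern above: every patch is installed through MegatronAdaptation.register(target, replacement, apply_wrapper=...), where target is a dotted path and apply_wrapper marks the replacement as a decorator/factory to apply to the original attribute (as with gpt_model_init_wrapper and the torch.compile(...) entries). The registry implementation itself is not part of this diff; the standalone sketch below only illustrates, under that assumption, how such a call could resolve the path and monkey-patch the final attribute.

# Illustrative sketch only -- not the repository's MegatronAdaptation code.
# It mimics the register(target, replacement, apply_wrapper=...) calls above:
# resolve the longest importable module prefix, walk the remaining attribute
# chain, and swap out the final attribute.
import importlib

def register_patch(target: str, replacement, apply_wrapper: bool = False):
    parts = target.split('.')
    # Find the longest importable module prefix; the rest is an attribute chain.
    for i in range(len(parts), 0, -1):
        try:
            owner = importlib.import_module('.'.join(parts[:i]))
            attr_chain = parts[i:]
            break
        except ModuleNotFoundError:
            continue
    else:
        raise ImportError(f"no importable prefix found in {target!r}")
    # Walk down to the attribute's owner (e.g. the GPTModel class).
    for name in attr_chain[:-1]:
        owner = getattr(owner, name)
    original = getattr(owner, attr_chain[-1])
    # apply_wrapper=True: the registered object is a decorator/factory
    # (gpt_model_init_wrapper, torch.compile(...)) applied to the original.
    patched = replacement(original) if apply_wrapper else replacement
    setattr(owner, attr_chain[-1], patched)
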
......@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.torch_norm import L2Norm
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
......@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_layer_with_flux_spec(
......@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
qk_l2_norm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use flux modules (required for fp8 training).
......@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
Returns:
ModuleSpec: Module specification with flux modules
......@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
)
if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
......@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
linear_qkv=FluxColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=FluxRowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
q_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
k_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
),
),
self_attn_bda=get_bias_dropout_add,
......
......@@ -13,8 +13,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
def gpt_model_init_wrapper(fn):
@wraps(fn)
......@@ -22,12 +20,13 @@ def gpt_model_init_wrapper(fn):
fn(self, *args, **kwargs)
# Output
if self.post_process or self.mtp_process:
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
parallel_linear_impl = FluxColumnParallelLinear
else:
parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = parallel_linear_impl(
if (
(self.post_process or self.mtp_process)
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size,
self.vocab_size,
config=self.config,
......@@ -41,8 +40,8 @@ def gpt_model_init_wrapper(fn):
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
return wrapper
......
from .layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear,
)
\ No newline at end of file
......@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super(FluxColumnParallelLinear, self).__init__(
input_size=input_size,
......@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
disable_grad_reduce=disable_grad_reduce,
tp_group=tp_group,
)
# flux params
......@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super(FluxRowParallelLinear, self).__init__(
......@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
# flux params
......
......@@ -23,6 +23,7 @@ def transformer_config_post_init_wrapper(fn):
##################
self.flux_transpose_weight = args.flux_transpose_weight
return wrapper
......@@ -33,6 +34,12 @@ class ExtraTransformerConfig:
##################
flux_transpose_weight: bool = False
combined_1f1b: bool = False
"""If true, use combined 1F1B for communication hiding."""
combined_1f1b_recipe: str = 'ep_a2a'
"""Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
@dataclass
class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
......
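
The two new fields ride on ExtraTransformerConfig and therefore on TransformerConfigPatch. A rough usage sketch follows; the import path and the minimal set of required base TransformerConfig arguments are assumptions not shown in this diff.

# Hypothetical usage of the patched config; the field names come from the diff,
# everything else (import path, required base-config values) is assumed.
from dcu_megatron.core.transformer.transformer_config import TransformerConfigPatch

config = TransformerConfigPatch(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=16,
    combined_1f1b=True,             # enable combined 1F1B communication hiding
    combined_1f1b_recipe='ep_a2a',  # 'ep_a2a' or 'golden'
)
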
......@@ -26,6 +26,8 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
parser = _add_extra_training_args(parser)
parser = _add_extra_distributed_args(parser)
parser = _add_extra_tokenizer_args(parser)
parser = _add_extra_moe_args(parser)
parser = _add_flux_args(parser)
return parser
......@@ -128,6 +130,18 @@ def _add_extra_tokenizer_args(parser):
return parser
def _add_extra_moe_args(parser):
group = parser.add_argument_group(title="extra moe args")
group.add_argument('--combined-1f1b', action='store_true',
help='Batch-level overlapping in 1f1b stage.')
group.add_argument('--combined-1f1b-recipe', type=str,
choices=['ep_a2a', 'golden'],
default='golden',
help='Options are "ep_a2a" and "golden".')
return parser
def _add_flux_args(parser):
group = parser.add_argument_group(title='flux args')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
......
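
For reference, the new MoE arguments above behave as follows when parsed in isolation. This standalone snippet mirrors _add_extra_moe_args; how these values are then copied onto combined_1f1b / combined_1f1b_recipe on the config is assumed to follow the usual Megatron args-to-config mapping and is not part of this diff.

# Standalone check of the new flags, copied from _add_extra_moe_args above.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="extra moe args")
group.add_argument('--combined-1f1b', action='store_true',
                   help='Batch-level overlapping in 1f1b stage.')
group.add_argument('--combined-1f1b-recipe', type=str,
                   choices=['ep_a2a', 'golden'], default='golden',
                   help='Options are "ep_a2a" and "golden".')

args = parser.parse_args(['--combined-1f1b', '--combined-1f1b-recipe', 'ep_a2a'])
assert args.combined_1f1b is True
assert args.combined_1f1b_recipe == 'ep_a2a'
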