Commit 43770f8e authored by dongcl

bug fix

parent b85974a6
@@ -89,9 +89,12 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass
 
     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
 
         # GPT Model
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+                                    gpt_model_init_wrapper,
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)
@@ -171,8 +174,6 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     FluxRowParallelLinear)
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)
-        MegatronAdaptation.register("megatron.core.tensor_parallel.layers",
-                                    FluxColumnParallelLinear)
 
     def patch_pipeline_parallel(self):
         pass
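Note on the registration pattern above: MegatronAdaptation.register() itself is not part of this diff, so the sketch below is only a guess at the mechanism being relied on. The point is the difference between the two calls: gpt_model_forward is registered as a straight replacement, while gpt_model_init_wrapper is registered with apply_wrapper=True, i.e. it is expected to receive the original GPTModel.__init__ and call it before doing extra work. A minimal registry along those lines might look like this (all names except the registered targets are hypothetical):

import importlib

_PATCHES = []  # (module path, class name, attribute name, replacement, wrap original?)

def register(target, replacement, apply_wrapper=False):
    # target is dotted, e.g. "megatron.core.models.gpt.gpt_model.GPTModel.forward";
    # for illustration, assume the last two components are <Class>.<attribute>.
    *module_parts, cls_name, attr_name = target.split(".")
    _PATCHES.append((".".join(module_parts), cls_name, attr_name, replacement, apply_wrapper))

def apply_patches():
    for module_path, cls_name, attr_name, replacement, apply_wrapper in _PATCHES:
        owner = getattr(importlib.import_module(module_path), cls_name)
        original = getattr(owner, attr_name)
        # apply_wrapper=True: the registered callable is a decorator that wraps the
        # original; otherwise the original attribute is replaced outright.
        setattr(owner, attr_name, replacement(original) if apply_wrapper else replacement)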
......
@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.torch_norm import L2Norm
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear
 )
-from dcu_megatron.core.transformer.multi_token_prediction import (
-    MultiTokenPredictionBlockSubmodules,
-    get_mtp_layer_offset,
-    get_mtp_layer_spec,
-    get_mtp_num_layers_to_build,
-)
 
 
 def get_gpt_layer_with_flux_spec(
@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
     multi_latent_attention: Optional[bool] = False,
     fp8: Optional[str] = None,  # pylint: disable=unused-arguments
     moe_use_legacy_grouped_gemm: Optional[bool] = False,
+    qk_l2_norm: Optional[bool] = False,
 ) -> ModuleSpec:
     """Use this spec to use flux modules (required for fp8 training).
 
@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
         fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
         moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
             Defaults to False.
+        qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
 
     Returns:
         ModuleSpec: Module specification with flux modules
@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
     )
 
     if multi_latent_attention:
+        assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(
@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
                     linear_qkv=FluxColumnParallelLinear,
                     core_attention=TEDotProductAttention,
                     linear_proj=FluxRowParallelLinear,
-                    q_layernorm=qk_norm if qk_layernorm else IdentityOp,
-                    k_layernorm=qk_norm if qk_layernorm else IdentityOp,
+                    q_layernorm=(
+                        L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                    ),
+                    k_layernorm=(
+                        L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                    ),
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
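For readers skimming the spec change: the new q/k normalization choice reduces to a three-way selection in which qk_l2_norm takes precedence, and the MLA branch explicitly rejects it. A small helper expressing the same logic (the helper itself is illustrative, not part of the patch; the IdentityOp import path is the usual Megatron-Core one):

from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.torch_norm import L2Norm

def pick_qk_norm(qk_norm, qk_layernorm: bool, qk_l2_norm: bool):
    # Mirrors the inline expression used for q_layernorm/k_layernorm:
    # L2Norm wins when requested, then the configured qk_norm if qk_layernorm
    # is enabled, otherwise a pass-through IdentityOp.
    if qk_l2_norm:
        return L2Norm
    return qk_norm if qk_layernorm else IdentityOp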
......
@@ -2,14 +2,48 @@ from collections import OrderedDict
 from typing import Optional
 from functools import wraps
+import os
 
 import torch
 from torch import Tensor
 
+from megatron.core import tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import WrappedTensor, deprecate_inference_params
 
+from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
+
+
+def gpt_model_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # Output
+        if (
+            (self.post_process or self.mtp_process)
+            and int(os.getenv("USE_FLUX_OVERLAP", "0"))
+        ):
+            self.output_layer = FluxColumnParallelLinear(
+                self.config.hidden_size,
+                self.vocab_size,
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=False,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=self.pre_process
+                and self.share_embeddings_and_output_weights,
+                embedding_activation_buffer=self.embedding_activation_buffer,
+                grad_output_buffer=self.grad_output_buffer,
+            )
+
+            if self.pre_process or self.post_process:
+                self.setup_embeddings_and_output_layer()
+
+    return wrapper
+
 
 def gpt_model_forward(
     self,
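The wrapper above only swaps GPTModel.output_layer for the flux variant when USE_FLUX_OVERLAP parses to a non-zero integer (a non-integer value would raise ValueError) and then re-runs setup_embeddings_and_output_layer() so the replacement layer is wired up like the original. It is normally installed through the register(..., apply_wrapper=True) call in the first hunk; applying it by hand would look roughly like this (sketch only, import paths assumed from this repository layout):

import os

# Any non-zero integer enables the flux output layer; "0" (the default)
# keeps the stock output layer created by the original __init__.
os.environ["USE_FLUX_OVERLAP"] = "1"

from megatron.core.models.gpt.gpt_model import GPTModel
from dcu_megatron.core.models.gpt.gpt_model import gpt_model_init_wrapper

# Equivalent of registering with apply_wrapper=True: wrap the original __init__.
GPTModel.__init__ = gpt_model_init_wrapper(GPTModel.__init__)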
......
@@ -24,7 +24,7 @@ from megatron.core.tensor_parallel.mappings import (
 )
 from megatron.core.tensor_parallel import (
     ColumnParallelLinear,
-    RowParallelLinear,
+    RowParallelLinear
 )
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,
@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
         disable_grad_reduce: bool = False,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxColumnParallelLinear, self).__init__(
             input_size=input_size,
@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             is_expert=is_expert,
             tp_comm_buffer_name=tp_comm_buffer_name,
             disable_grad_reduce=disable_grad_reduce,
+            tp_group=tp_group,
         )
 
         # flux params
@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
         keep_master_weight_for_test: bool = False,
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxRowParallelLinear, self).__init__(
@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
             stride=stride,
             keep_master_weight_for_test=keep_master_weight_for_test,
             is_expert=is_expert,
-            tp_comm_buffer_name=tp_comm_buffer_name
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
        )
 
         # flux params
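The tp_group changes above simply accept the new keyword and forward it to the parent ColumnParallelLinear/RowParallelLinear constructors, which take an explicit tensor-parallel process group in recent Megatron-Core, instead of rejecting it. A minimal usage sketch, assuming parallel state is already initialized and transformer_config is an existing TransformerConfig instance (both are assumptions for illustration):

from megatron.core import parallel_state
from dcu_megatron.core.tensor_parallel.layers import FluxRowParallelLinear

# Explicitly pass the tensor-parallel group; before this commit the Flux
# subclasses did not accept the keyword at all.
proj = FluxRowParallelLinear(
    transformer_config.hidden_size,
    transformer_config.hidden_size,
    config=transformer_config,
    init_method=transformer_config.init_method,
    bias=False,
    input_is_parallel=True,
    skip_bias_add=True,
    tp_group=parallel_state.get_tensor_model_parallel_group(),
)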
......