Commit 43770f8e authored by dongcl

bug fix

parent b85974a6
......@@ -89,9 +89,12 @@ class CoreAdaptation(MegatronAdaptationABC):
pass
def patch_core_models(self):
from ..core.models.gpt.gpt_model import gpt_model_forward
from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
gpt_model_forward)
......@@ -171,8 +174,6 @@ class CoreAdaptation(MegatronAdaptationABC):
FluxRowParallelLinear)
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
MegatronAdaptation.register("megatron.core.tensor_parallel.layers",
FluxColumnParallelLinear)
def patch_pipeline_parallel(self):
pass
......
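The hunks above register gpt_model_init_wrapper against GPTModel.__init__ with apply_wrapper=True, keep the forward replacement, and drop the stray registration of FluxColumnParallelLinear under megatron.core.tensor_parallel.layers. As a minimal sketch only (not the project's actual MegatronAdaptation implementation), a registry of this shape can resolve a dotted path and either replace the target or pass it through a wrapper:

```python
import importlib


def apply_patch(target_path, replacement, apply_wrapper=False):
    """Resolve 'pkg.module.Class.attr' and swap (or wrap) that attribute in place."""
    module_path, cls_name, attr_name = target_path.rsplit(".", 2)
    owner = getattr(importlib.import_module(module_path), cls_name)
    original = getattr(owner, attr_name)
    # With apply_wrapper=True the replacement is treated as a decorator that receives
    # the original callable (e.g. gpt_model_init_wrapper wrapping GPTModel.__init__).
    setattr(owner, attr_name, replacement(original) if apply_wrapper else replacement)
```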
......@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.torch_norm import L2Norm
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
......@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_layer_with_flux_spec(
......@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
qk_l2_norm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use flux modules (required for fp8 training).
......@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
Returns:
ModuleSpec: Module specification with flux modules
......@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
)
if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
......@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
linear_qkv=FluxColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=FluxRowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
q_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
k_layernorm=(
L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
),
),
),
self_attn_bda=get_bias_dropout_add,
......
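With the new qk_l2_norm flag, the spec gives L2Norm precedence over the regular qk_layernorm path, and the added assert rejects it under MLA. An illustrative helper capturing that precedence (the imports are the Megatron modules referenced above; the helper itself is not part of this patch):

```python
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.torch_norm import L2Norm


def pick_qk_norm(qk_norm, qk_layernorm, qk_l2_norm, multi_latent_attention=False):
    """Return the module used for q_layernorm / k_layernorm in the spec above."""
    if multi_latent_attention:
        assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
    if qk_l2_norm:
        return L2Norm  # parameter-free L2 normalization of queries/keys
    return qk_norm if qk_layernorm else IdentityOp
```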
......@@ -2,14 +2,48 @@ from collections import OrderedDict
from typing import Optional
from functools import wraps
import os
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
def gpt_model_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
# Output
if (
(self.post_process or self.mtp_process)
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size,
self.vocab_size,
config=self.config,
init_method=self.config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
return wrapper
def gpt_model_forward(
self,
......
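gpt_model_init_wrapper follows the usual decorate-__init__ pattern: run the original constructor untouched, then, when USE_FLUX_OVERLAP=1, swap self.output_layer for FluxColumnParallelLinear and re-run setup_embeddings_and_output_layer() so embedding/output weight sharing still holds. A self-contained toy example of the same pattern (the Example class here is hypothetical, purely for illustration):

```python
from functools import wraps


def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)   # run the original constructor first
        self.patched = True         # then adjust the fully initialized instance
    return wrapper


class Example:
    def __init__(self, value):
        self.value = value


Example.__init__ = init_wrapper(Example.__init__)
assert Example(3).patched and Example(3).value == 3
```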
......@@ -24,7 +24,7 @@ from megatron.core.tensor_parallel.mappings import (
)
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
RowParallelLinear
)
from megatron.core.tensor_parallel.layers import (
custom_fwd,
......@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super(FluxColumnParallelLinear, self).__init__(
input_size=input_size,
......@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
disable_grad_reduce=disable_grad_reduce,
tp_group=tp_group,
)
# flux params
......@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super(FluxRowParallelLinear, self).__init__(
......@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
stride=stride,
keep_master_weight_for_test=keep_master_weight_for_test,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
# flux params
......
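Both Flux linear layers now accept a tp_group argument and forward it to their Megatron parent classes, so callers can pass an explicit tensor-parallel process group instead of relying on the implicit default. A hedged usage sketch (the config object, vocab size, and helper name are placeholders, not values from this commit):

```python
from typing import Optional

import torch

from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear


def build_output_layer(config, vocab_size: int, parallel_output: bool = True,
                       tp_group: Optional[torch.distributed.ProcessGroup] = None):
    # tp_group is now forwarded to the parent ColumnParallelLinear rather than
    # always using the default tensor-parallel group.
    return FluxColumnParallelLinear(
        config.hidden_size,
        vocab_size,
        config=config,
        init_method=config.init_method,
        bias=False,
        skip_bias_add=False,
        gather_output=not parallel_output,
        tp_group=tp_group,
    )
```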