Commit ec7c8bc3 authored by dongcl

Replace TE linear layers with their Flux counterparts when Flux overlap is enabled

parent 138b70a2
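The new registrations rebind the TE class names that Megatron's transformer-engine layer spec refers to. Below is a minimal sketch of the effect, assuming MegatronAdaptation.register() ultimately rebinds the dotted target path to the replacement object; the helper apply_patch() is hypothetical, and only the register() calls themselves appear in this commit.

```python
# Minimal sketch, assuming MegatronAdaptation.register() ultimately rebinds the
# attribute named by the dotted path to the replacement object. The helper name
# apply_patch() is hypothetical; only the register() calls appear in this commit.
import importlib


def apply_patch(target: str, replacement) -> None:
    """Rebind a module-level name, e.g.
    'megatron.core.extensions.transformer_engine.TEColumnParallelLinear',
    to a replacement class such as FluxColumnParallelLinear."""
    module_path, attr_name = target.rsplit(".", 1)
    module = importlib.import_module(module_path)
    setattr(module, attr_name, replacement)
```

Registering the spec builder the same way redirects callers of get_gpt_layer_with_transformer_engine_spec to get_gpt_layer_with_flux_spec, so existing call sites pick up the Flux linears unchanged.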
@@ -190,27 +190,17 @@ class CoreAdaptation(MegatronAdaptationABC):
         # flux
         if os.getenv("USE_FLUX_OVERLAP", 0):
             import flux
             from ..core.tensor_parallel import (
-                ColumnParallelLinearPatch,
-                RowParallelLinearPatch,
-                column_parallel_linear_init_wrapper,
-                row_parallel_linear_init_wrapper
+                FluxColumnParallelLinear,
+                FluxRowParallelLinear
             )
             from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
-                                        column_parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
-                                        ColumnParallelLinearPatch.forward)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
-                                        row_parallel_linear_init_wrapper,
-                                        apply_wrapper=True)
-            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
-                                        RowParallelLinearPatch.forward)
-            MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec",
+            MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TEColumnParallelLinear",
+                                        FluxColumnParallelLinear)
+            MegatronAdaptation.register("megatron.core.extensions.transformer_engine.TERowParallelLinear",
+                                        FluxRowParallelLinear)
+            MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                         get_gpt_layer_with_flux_spec)
 
     def patch_training(self):
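The Flux layer classes themselves live in the collapsed layers.py diff, so their implementation is not visible here. As a point of reference for what FluxRowParallelLinear must stay numerically equivalent to: a plain row-parallel linear is a local shard GEMM followed by an all-reduce over the tensor-parallel group, and the flux kernels are expected to overlap that communication with the GEMM rather than run the two steps back to back. A hedged baseline sketch (not the Flux implementation) follows.

```python
# Baseline only: non-overlapped row-parallel linear for comparison. The Flux
# implementation (collapsed layers.py in this commit) is expected to overlap
# the reduction with the GEMM; this function just states the equivalent math.
import torch
import torch.distributed as dist


def row_parallel_linear_reference(x: torch.Tensor,
                                  weight_shard: torch.Tensor,
                                  tp_group) -> torch.Tensor:
    # x: [..., hidden_in // tp_size]  (input already split along the reduction dim)
    # weight_shard: [hidden_out, hidden_in // tp_size]
    partial = torch.matmul(x, weight_shard.t())   # local GEMM on this rank's shard
    dist.all_reduce(partial, group=tp_group)      # sum partial outputs across TP ranks
    return partial
```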
@@ -3,7 +3,6 @@ from typing import Optional
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
-from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
@@ -17,6 +16,9 @@ from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
+from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version
 
 try:
@@ -79,13 +81,13 @@ def get_gpt_layer_with_flux_spec(
                 module=MLASelfAttention,
                 params={"attn_mask_type": AttnMaskType.causal},
                 submodules=MLASelfAttentionSubmodules(
-                    linear_q_proj=ColumnParallelLinear,
-                    linear_q_down_proj=ColumnParallelLinear,
-                    linear_q_up_proj=ColumnParallelLinear,
-                    linear_kv_down_proj=ColumnParallelLinear,
-                    linear_kv_up_proj=ColumnParallelLinear,
+                    linear_q_proj=FluxColumnParallelLinear,
+                    linear_q_down_proj=FluxColumnParallelLinear,
+                    linear_q_up_proj=FluxColumnParallelLinear,
+                    linear_kv_down_proj=FluxColumnParallelLinear,
+                    linear_kv_up_proj=FluxColumnParallelLinear,
                     core_attention=TEDotProductAttention,
-                    linear_proj=RowParallelLinear,
+                    linear_proj=FluxRowParallelLinear,
                     q_layernorm=TENorm if qk_layernorm else IdentityOp,
                     kv_layernorm=TENorm if qk_layernorm else IdentityOp,
                 ),
@@ -111,9 +113,9 @@ def get_gpt_layer_with_flux_spec(
             module=SelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=SelfAttentionSubmodules(
-                linear_qkv=ColumnParallelLinear,
+                linear_qkv=FluxColumnParallelLinear,
                 core_attention=TEDotProductAttention,
-                linear_proj=RowParallelLinear,
+                linear_proj=FluxRowParallelLinear,
                 q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                 k_layernorm=qk_norm if qk_layernorm else IdentityOp,
             ),
@@ -145,8 +147,8 @@ def get_mlp_module_flux_spec(
         return ModuleSpec(
             module=MLP,
             submodules=MLPSubmodules(
-                linear_fc1=ColumnParallelLinear,
-                linear_fc2=RowParallelLinear,
+                linear_fc1=FluxColumnParallelLinear,
+                linear_fc2=FluxRowParallelLinear,
             ),
         )
     else:
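A usage sketch for the spec entry point these hunks modify. The import path and the qk_layernorm keyword are assumptions, based on the dcu_megatron imports visible above and on the signature of Megatron's get_gpt_layer_with_transformer_engine_spec, which this function appears to mirror.

```python
# Hedged sketch: the import path and keyword argument are assumed, not
# confirmed by this diff.
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

layer_spec = get_gpt_layer_with_flux_spec(qk_layernorm=True)

# The attention and MLP submodules now name FluxColumnParallelLinear /
# FluxRowParallelLinear instead of the TE (or vanilla tensor-parallel) linears.
print(layer_spec.submodules.self_attention.submodules.linear_qkv)
```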
 from .layers import (
     column_parallel_linear_init_wrapper,
     row_parallel_linear_init_wrapper,
     ColumnParallelLinearPatch,
     RowParallelLinearPatch,
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear,
     vocab_parallel_embedding_forward,
     vocab_parallel_embedding_init,
 )
\ No newline at end of file
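With the package __init__ re-exporting the new classes, downstream code can import them from the package rather than reaching into layers.py directly; the absolute package path below is assumed from the dcu_megatron import added in gpt_layer_specs above.

```python
# Path assumed from the dcu_megatron.core.tensor_parallel.layers import added
# in gpt_layer_specs; shown only to illustrate the re-export.
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear, FluxRowParallelLinear
```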