Commit 83c871fb authored by wanglong3
Browse files

feat: support shared expert fusion.

parent cca00f5c
...@@ -921,12 +921,12 @@ class rocm_aiter_ops: ...@@ -921,12 +921,12 @@ class rocm_aiter_ops:
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod @classmethod
@if_aiter_supported # @if_aiter_supported
def is_fused_moe_enabled(cls) -> bool: def is_fused_moe_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FMOE_ENABLED return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod @classmethod
@if_aiter_supported # @if_aiter_supported
def is_fusion_moe_shared_experts_enabled(cls) -> bool: def is_fusion_moe_shared_experts_enabled(cls) -> bool:
return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
......
...@@ -1055,7 +1055,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1055,7 +1055,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use aiter triton fp4 bmm kernel # Whether to use aiter triton fp4 bmm kernel
# By default is enabled. # By default is enabled.
"VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( "VLLM_ROCM_USE_AITER_FP4BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1") os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "False").lower() in ("true", "1")
), ),
# Use AITER triton unified attention for V1 attention # Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
......
...@@ -335,8 +335,10 @@ class GroupedTopKRouter(BaseRouter): ...@@ -335,8 +335,10 @@ class GroupedTopKRouter(BaseRouter):
rocm_aiter_grouped_topk, rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
) )
enable_shared_experts_fusion = True
else: else:
grouped_topk_impl = grouped_topk grouped_topk_impl = grouped_topk
enable_shared_experts_fusion = False
if self.use_fused_gate: if self.use_fused_gate:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
...@@ -347,7 +349,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -347,7 +349,7 @@ class GroupedTopKRouter(BaseRouter):
self.num_expert_group, self.num_expert_group,
self.topk_group, self.topk_group,
self.top_k, self.top_k,
0, self.num_fused_shared_experts if enable_shared_experts_fusion else 0,
self.routed_scaling_factor, self.routed_scaling_factor,
) )
else: else:
...@@ -358,7 +360,7 @@ class GroupedTopKRouter(BaseRouter): ...@@ -358,7 +360,7 @@ class GroupedTopKRouter(BaseRouter):
self.topk_group, self.topk_group,
self.top_k, self.top_k,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
n_share_experts_fusion=0, n_share_experts_fusion = (self.num_fused_shared_experts if enable_shared_experts_fusion else 0),
) )
else: else:
topk_weights, topk_ids = grouped_topk_impl( topk_weights, topk_ids = grouped_topk_impl(
......
...@@ -324,8 +324,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -324,8 +324,12 @@ class DeepseekV2MoE(nn.Module):
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
gate=self.gate, gate=self.gate,
num_experts=config.n_routed_experts, # num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, # top_k=config.num_experts_per_tok,
num_experts=config.n_routed_experts
+ (config.n_shared_experts if self.is_fusion_moe_shared_experts_enabled else 0),
top_k = config.num_experts_per_tok
+ (config.n_shared_experts if self.is_fusion_moe_shared_experts_enabled else 0),
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=False, reduce_results=False,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment