Commit 0639678c authored by laibao
Browse files

feat(moe): 增加 LightOP moe_sum+mul+add 融合并打通参数透传

  新增环境变量 VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD 用于控制
  fused sum+mul+add 开关。
  在 DeepseekV2MoE 中增加 fused 路径,预计算 shared_output,并下传 iqis 与 routed_scaling_factor。
  扩展 FusedMoE/SharedFusedMoE 及相关 custom op 接口,统一透传 i_q/i_s/shared_output/routed_scaling_factor。
  同步适配 Triton、Marlin W16A16、SlimQuant W4A8、CompressedTensors W8A8 等实现,支持在内核侧完成 sum+mul+add。
parent efa6bed2
...@@ -313,6 +313,7 @@ if TYPE_CHECKING: ...@@ -313,6 +313,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
...@@ -1957,6 +1958,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1957,6 +1958,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_USE_CUDA_GRAPH_SIZES", "False").lower() in lambda: (os.getenv("VLLM_USE_CUDA_GRAPH_SIZES", "False").lower() in
("true", "1")), ("true", "1")),
# If set to 1/True, vLLM uses the lightop fused moe_sum + mul + add kernel (the scaling factor and the bias are passed to the expert-sum kernel)
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD",
"False").lower() in ("true", "1")),
# If set to 1/True, enable the fused top-p/top-k kernel in lightop # If set to 1/True, enable the fused top-p/top-k kernel in lightop
"VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK": "VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK", "False").lower() in
......
...@@ -397,7 +397,22 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -397,7 +397,22 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
) )
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
if envs.VLLM_USE_LIGHTOP_MOE_SUM: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD and shared_output is not None:
from lightop import op as op
factor = (
float(routed_scaling_factor)
if routed_scaling_factor is not None
else 1.0
)
op.moe_sum(
input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx],
bias=shared_output[begin_chunk_idx:end_chunk_idx],
expert_mask=None,
num_local_tokens=None,
factor=factor,
)
elif envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None, output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
...@@ -406,4 +421,4 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -406,4 +421,4 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
\ No newline at end of file
...@@ -1404,6 +1404,8 @@ def inplace_fused_experts( ...@@ -1404,6 +1404,8 @@ def inplace_fused_experts(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None: ) -> None:
fused_experts_impl( fused_experts_impl(
hidden_states, hidden_states,
...@@ -1433,6 +1435,8 @@ def inplace_fused_experts( ...@@ -1433,6 +1435,8 @@ def inplace_fused_experts(
w1_bias, w1_bias,
w2_bias, w2_bias,
use_nn_moe, use_nn_moe,
shared_output,
routed_scaling_factor,
) )
...@@ -1463,6 +1467,8 @@ def inplace_fused_experts_fake( ...@@ -1463,6 +1467,8 @@ def inplace_fused_experts_fake(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None: ) -> None:
pass pass
...@@ -1508,7 +1514,9 @@ def outplace_fused_experts( ...@@ -1508,7 +1514,9 @@ def outplace_fused_experts(
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return fused_experts_impl( return fused_experts_impl(
hidden_states, hidden_states,
...@@ -1540,6 +1548,8 @@ def outplace_fused_experts( ...@@ -1540,6 +1548,8 @@ def outplace_fused_experts(
use_nn_moe, use_nn_moe,
i_q=i_q, i_q=i_q,
i_s=i_s, i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
...@@ -1550,6 +1560,7 @@ def outplace_fused_experts_fake( ...@@ -1550,6 +1560,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
...@@ -1569,6 +1580,8 @@ def outplace_fused_experts_fake( ...@@ -1569,6 +1580,8 @@ def outplace_fused_experts_fake(
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1618,7 +1631,9 @@ def fused_experts( ...@@ -1618,7 +1631,9 @@ def fused_experts(
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
...@@ -1652,6 +1667,8 @@ def fused_experts( ...@@ -1652,6 +1667,8 @@ def fused_experts(
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
i_q=i_q, i_q=i_q,
i_s=i_s, i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
...@@ -1712,7 +1729,9 @@ def fused_experts_impl( ...@@ -1712,7 +1729,9 @@ def fused_experts_impl(
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. # Check constraints.
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
...@@ -1820,6 +1839,8 @@ def fused_experts_impl( ...@@ -1820,6 +1839,8 @@ def fused_experts_impl(
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=False, use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
) )
if use_nn_moe: if use_nn_moe:
...@@ -2283,6 +2304,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2283,6 +2304,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
): ):
# Check constraints. # Check constraints.
if self.quant_config.use_int4_w4a16: if self.quant_config.use_int4_w4a16:
...@@ -2416,10 +2439,33 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2416,10 +2439,33 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
# separate function is required for MoE + LoRA # separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output) self.moe_sum(
intermediate_cache3,
output,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def moe_sum(
self,
input: torch.Tensor,
output: torch.Tensor,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> None:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD and shared_output is not None:
from lightop import op
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: op.moe_sum(
ops.moe_sum(input, output) input=input,
output=output,
bias=shared_output,
expert_mask=None,
num_local_tokens=None,
factor=float(routed_scaling_factor),
)
else:
ops.moe_sum(input, output)
class TritonWNA16Experts(TritonExperts): class TritonWNA16Experts(TritonExperts):
......
...@@ -1670,7 +1670,9 @@ class FusedMoE(CustomOp): ...@@ -1670,7 +1670,9 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1] og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states: if self.hidden_size != og_hidden_states:
...@@ -1709,7 +1711,13 @@ class FusedMoE(CustomOp): ...@@ -1709,7 +1711,13 @@ class FusedMoE(CustomOp):
assert not isinstance(fused_output, tuple) assert not isinstance(fused_output, tuple)
else: else:
fused_output = torch.ops.vllm.moe_forward( fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, encode_layer_name() hidden_states,
router_logits,
encode_layer_name(),
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
return reduce_output(fused_output)[..., :og_hidden_states] return reduce_output(fused_output)[..., :og_hidden_states]
else: else:
...@@ -1723,13 +1731,21 @@ class FusedMoE(CustomOp): ...@@ -1723,13 +1731,21 @@ class FusedMoE(CustomOp):
else: else:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None: if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared( shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name(), hidden_states,
router_logits,
encode_layer_name(),
i_q=i_q, i_q=i_q,
i_s=i_s i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
else: else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared( shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name() hidden_states,
router_logits,
encode_layer_name(),
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
return ( return (
reduce_output(shared_output)[..., :og_hidden_states], reduce_output(shared_output)[..., :og_hidden_states],
...@@ -1747,12 +1763,18 @@ class FusedMoE(CustomOp): ...@@ -1747,12 +1763,18 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None: return self.forward_native(
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s) hidden_states,
else: router_logits,
return self.forward_native(hidden_states, router_logits) i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def forward_impl_chunked( def forward_impl_chunked(
self, self,
...@@ -1895,10 +1917,11 @@ class FusedMoE(CustomOp): ...@@ -1895,10 +1917,11 @@ class FusedMoE(CustomOp):
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
self.ensure_dp_chunking_init() self.ensure_dp_chunking_init()
...@@ -2026,21 +2049,25 @@ class FusedMoE(CustomOp): ...@@ -2026,21 +2049,25 @@ class FusedMoE(CustomOp):
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signture of this is wrong due to the hack. x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
i_q=i_q, i_q=i_q,
i_s=i_s i_s=i_s,
) shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else: else:
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signture of this is wrong due to the hack. x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe use_nn_moe=self.use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
if has_separate_shared_experts: if has_separate_shared_experts:
...@@ -2164,11 +2191,20 @@ def moe_forward( ...@@ -2164,11 +2191,20 @@ def moe_forward(
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is None assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s) return self.forward_impl(
hidden_states,
router_logits,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def moe_forward_fake( def moe_forward_fake(
...@@ -2176,7 +2212,9 @@ def moe_forward_fake( ...@@ -2176,7 +2212,9 @@ def moe_forward_fake(
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -2195,14 +2233,28 @@ def moe_forward_shared( ...@@ -2195,14 +2233,28 @@ def moe_forward_shared(
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is not None assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None: if envs.USE_FUSED_RMS_QUANT and i_q is not None and i_s is not None:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s) return self.forward_impl(
hidden_states,
router_logits,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else: else:
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(
hidden_states,
router_logits,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def moe_forward_shared_fake( def moe_forward_shared_fake(
...@@ -2210,7 +2262,9 @@ def moe_forward_shared_fake( ...@@ -2210,7 +2262,9 @@ def moe_forward_shared_fake(
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
......
...@@ -61,7 +61,10 @@ class SharedFusedMoE(FusedMoE): ...@@ -61,7 +61,10 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if self._shared_experts is not None:
...@@ -92,12 +95,16 @@ class SharedFusedMoE(FusedMoE): ...@@ -92,12 +95,16 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0], i_q=iqis[0],
i_s=iqis[1] i_s=iqis[1],
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
else: else:
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits router_logits=router_logits,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
else: else:
if envs.USE_FUSED_RMS_QUANT and iqis is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
...@@ -107,12 +114,16 @@ class SharedFusedMoE(FusedMoE): ...@@ -107,12 +114,16 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0], i_q=iqis[0],
i_s=iqis[1] i_s=iqis[1],
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
else: else:
shared_out, fused_out = super().forward( shared_out, fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits router_logits=router_logits,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
# ensure early TP reduction of shared expert outputs when required # ensure early TP reduction of shared expert outputs when required
if ( if (
...@@ -122,4 +133,4 @@ class SharedFusedMoE(FusedMoE): ...@@ -122,4 +133,4 @@ class SharedFusedMoE(FusedMoE):
and self.must_reduce_shared_expert_outputs() and self.must_reduce_shared_expert_outputs()
): ):
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out return shared_out, fused_out
\ No newline at end of file
...@@ -370,7 +370,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -370,7 +370,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, **_ use_nn_moe: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
**_,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
layer=layer, layer=layer,
...@@ -378,6 +381,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -378,6 +381,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
...@@ -397,6 +402,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -397,6 +402,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
getattr(layer, "_marlin_w16a16_moe_enabled", False) getattr(layer, "_marlin_w16a16_moe_enabled", False)
...@@ -415,6 +422,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -415,6 +422,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.get_fused_moe_quant_config(layer), quant_config=self.get_fused_moe_quant_config(layer),
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
assert self.kernel is not None assert self.kernel is not None
...@@ -430,6 +439,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -430,6 +439,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
def forward_monolithic_cpu( def forward_monolithic_cpu(
......
...@@ -400,7 +400,9 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -400,7 +400,9 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -424,7 +426,9 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -424,7 +426,9 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
i_q=i_q, i_q=i_q,
i_s=i_s i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
def select_gemm_impl( def select_gemm_impl(
...@@ -461,4 +465,4 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -461,4 +465,4 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return DeepGemmExperts(moe_config=self.moe, return DeepGemmExperts(moe_config=self.moe,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
N=self.N, N=self.N,
K=self.K) K=self.K)
\ No newline at end of file
...@@ -225,9 +225,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -225,9 +225,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin( return fused_experts_impl_w4a8_marlin(
x, x,
...@@ -246,5 +247,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -246,5 +247,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale a2_scale=layer.w2_input_scale,
) shared_output=shared_output,
\ No newline at end of file routed_scaling_factor=routed_scaling_factor,
)
...@@ -367,35 +367,64 @@ class DeepseekV2MoE(nn.Module): ...@@ -367,35 +367,64 @@ class DeepseekV2MoE(nn.Module):
if self.is_sequence_parallel: if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router: needs_post_moe_combine = (
# In this case, the gate/router runs inside the FusedMoE class getattr(self.experts, "dp_size", 1) > 1
fused_moe_out = self.experts( or getattr(self.experts, "pcp_size", 1) > 1
hidden_states=hidden_states, router_logits=hidden_states, )
iqis=iqis
) if (
else: envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
# router_logits: (num_tokens, n_experts) and self.shared_experts is not None
and not needs_post_moe_combine
):
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( shared_output = self.shared_experts(hidden_states, iqis=iqis)
hidden_states=hidden_states, router_logits=router_logits routed_scaling_factor = (
1.0 if self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor
)
self.experts.use_overlapped = False
self.experts._shared_experts = None
# Marlin W16A16 fused reduce consumes the precomputed
# shared_output and routed_scaling_factor directly.
_, final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
iqis=iqis,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
else:
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states,
router_logits=hidden_states,
iqis=iqis,
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
shared_output, final_hidden_states = fused_moe_out # Fix FP16 overflow
if self.shared_experts is None: # See DeepseekV2DecoderLayer for more details.
assert shared_output is None if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
# Fix FP16 overflow final_hidden_states *= self.routed_scaling_factor
# See DeepseekV2DecoderLayer for more details. elif self.shared_experts is not None:
if hidden_states.dtype != torch.float16: assert shared_output is not None
if not self.is_rocm_aiter_moe_enabled: shared_output *= 1.0 / self.routed_scaling_factor
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None: if self.shared_experts is not None:
assert shared_output is not None assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor final_hidden_states += shared_output
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel: if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment