Unverified commit 0e3fe896, authored by Michael Goin, committed by GitHub
Browse files

Support Llama 4 for fused_marlin_moe (#20457)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 1caca5a5
...@@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None,
...@@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights, topk_weights,
moe_block_size=block_size_m, moe_block_size=block_size_m,
top_k=topk, top_k=topk,
mul_topk_weights=False, mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None, is_ep=expert_map is not None,
b_q_type=quant_type, b_q_type=quant_type,
size_m=M, size_m=M,
...@@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights, topk_weights,
moe_block_size=block_size_m, moe_block_size=block_size_m,
top_k=1, top_k=1,
mul_topk_weights=True, mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None, is_ep=expert_map is not None,
b_q_type=quant_type, b_q_type=quant_type,
size_m=M * topk, size_m=M * topk,
...@@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, ...@@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None, global_scale2: Optional[torch.Tensor] = None,
......
...@@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
......
...@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): ...@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
global_scale1=layer.w13_weight_scale_2, global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2, global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
...@@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.use_marlin: if self.use_marlin:
assert activation == "silu", ( assert activation == "silu", (
f"{activation} not supported for Marlin MoE.") f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
...@@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
assert activation == "silu", ( assert activation == "silu", (
f"{activation} not supported for Marlin MoE.") f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx, g_idx1=layer.w13_weight_g_idx,
......
...@@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif self.use_marlin: elif self.use_marlin:
assert activation == "silu", ( assert activation == "silu", (
f"{activation} not supported for Marlin MoE.") f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
layer.w13_weight, layer.w13_weight,
...@@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id, quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
else: else:
......
...@@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"EPLB not supported for `GPTQMarlinMoEMethod` yet.") "EPLB not supported for `GPTQMarlinMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights, topk_weights,
topk_ids, topk_ids,
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
g_idx1=layer.w13_g_idx, g_idx1=layer.w13_g_idx,
......
...@@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale1=layer.w13_weight_scale_2, global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2, global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id, quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment