Unverified Commit 0e3fe896 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Support Llama 4 for fused_marlin_moe (#20457)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 1caca5a5
......@@ -24,6 +24,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
global_scale1: Optional[torch.Tensor] = None,
......@@ -149,7 +150,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
mul_topk_weights=apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M,
......@@ -182,7 +183,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
mul_topk_weights=not apply_router_weight_on_input,
is_ep=expert_map is not None,
b_q_type=quant_type,
size_m=M * topk,
......@@ -208,6 +209,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
global_scale1: Optional[torch.Tensor] = None,
global_scale2: Optional[torch.Tensor] = None,
......
......@@ -493,11 +493,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -520,6 +515,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
......
......@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
......@@ -669,8 +670,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
......@@ -681,6 +680,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
......@@ -1356,8 +1356,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -1381,6 +1379,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx,
......
......@@ -889,8 +889,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif self.use_marlin:
assert activation == "silu", (
f"{activation} not supported for Marlin MoE.")
assert not apply_router_weight_on_input, (
"Apply router weight on input not supported for Marlin MoE.")
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
......@@ -901,6 +899,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
else:
......
......@@ -645,10 +645,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -672,6 +668,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
......
......@@ -700,6 +700,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment