Commit 4a4fb3de authored by zhuwenwen
Browse files

fix nn_moe run error

parent 530e785f
...@@ -2179,10 +2179,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2179,10 +2179,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
): ):
# Check constraints. # Check constraints.
if self.quant_config.use_int4_w4a16: if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" assert hidden_states.size(-1) // 2 == w1.size(2) if not use_nn_moe else w1.size(1), "Hidden size mismatch"
else: else:
assert hidden_states.size(-1) == w1.size(2), ( expect_hidden_size = w1.size(2) if not use_nn_moe else w1.size(1)
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" assert hidden_states.size(-1) == expect_hidden_size, (
f"Hidden size mismatch {hidden_states.size(-1)} != {expect_hidden_size}"
) )
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
...@@ -2200,6 +2201,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2200,6 +2201,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
E, num_tokens, N, K, top_k_num = self.moe_problem_size( E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids hidden_states, w1, w2, topk_ids
) )
if use_nn_moe:
N = w1.size(-1)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
...@@ -2211,6 +2215,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2211,6 +2215,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.quant_config.config_name(hidden_states.dtype), self.quant_config.config_name(hidden_states.dtype),
num_tokens, num_tokens,
block_shape=self.block_shape, block_shape=self.block_shape,
use_nn_moe=use_nn_moe,
) )
if hidden_states.dtype == torch.bfloat16: if hidden_states.dtype == torch.bfloat16:
......
...@@ -1920,14 +1920,14 @@ class FusedMoE(CustomOp): ...@@ -1920,14 +1920,14 @@ class FusedMoE(CustomOp):
if self.capture is not None: if self.capture is not None:
self.capture(topk_ids) self.capture(topk_ids)
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signture of this is wrong due to the hack. x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
# use_fused_gate=self.use_fused_gate, )
)
if has_separate_shared_experts: if has_separate_shared_experts:
assert self.shared_experts is not None assert self.shared_experts is not None
......
...@@ -1131,11 +1131,15 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1131,11 +1131,15 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor: ) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size( _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
) )
if use_nn_moe:
N = w1.size(2)
num_chunks, CHUNK_SIZE = self._chunk_info(M_full) num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
def input_chunk_range(chunk_idx: int) -> tuple[int, int]: def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
...@@ -1206,6 +1210,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1206,6 +1210,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
) )
return fused_out return fused_out
...@@ -1289,6 +1294,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1289,6 +1294,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1350,6 +1356,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1350,6 +1356,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
) )
return self._finalize( return self._finalize(
......
...@@ -316,7 +316,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -316,7 +316,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.kernel is not None
return self.kernel( return self.kernel(
......
...@@ -1248,6 +1248,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1248,6 +1248,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -1263,6 +1264,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1263,6 +1264,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment