Commit 74306deb authored by 王敏
Browse files

[fix]修复nn_moe启动报错

parent ee58c1bf
......@@ -2373,10 +2373,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
assert hidden_states.size(-1) // 2 == w1.size(2) if not use_nn_moe else w1.size(1), "Hidden size mismatch"
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
expect_hidden_size = w1.size(2) if not use_nn_moe else w1.size(1)
assert hidden_states.size(-1) == expect_hidden_size, (
f"Hidden size mismatch {hidden_states.size(-1)} != {expect_hidden_size}"
)
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
......@@ -2395,6 +2396,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states, w1, w2, topk_ids
)
if use_nn_moe:
N = w1.size(-1)
if global_num_experts == -1:
global_num_experts = E
......@@ -2405,6 +2409,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.quant_config.config_name(hidden_states.dtype),
num_tokens,
block_shape=self.block_shape,
use_nn_moe=use_nn_moe,
)
if hidden_states.dtype == torch.bfloat16:
......
......@@ -1131,11 +1131,15 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids
)
if use_nn_moe:
N = w1.size(2)
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
......@@ -1206,6 +1210,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe
)
return fused_out
......@@ -1289,6 +1294,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
......@@ -1350,6 +1356,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
)
return self._finalize(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment