Commit 28edf80d authored by maxiao1
Browse files

低延迟模式支持dispatch int8 (Low-latency mode: support int8 dispatch)

parent 59259b56
......@@ -1001,7 +1001,7 @@ class DeepEPMoE(EPMoE):
dispatch_output: DeepEPLLOutput,
):
-hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
......@@ -1010,8 +1010,7 @@ class DeepEPMoE(EPMoE):
expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ----
-q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
+# q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
w13_scales = self.w13_weight_scale
......@@ -1023,7 +1022,7 @@ class DeepEPMoE(EPMoE):
# ---- first GEMM ----
torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
-q_a1_all, q_a1_scale,
+hidden_states, hidden_states_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
......@@ -1051,7 +1050,7 @@ class DeepEPMoE(EPMoE):
dispatch_output: DeepEPLLOutput,
):
-hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
+hidden_states, hidden_states_scale, topk_ids, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# base shapes
......@@ -1059,7 +1058,7 @@ class DeepEPMoE(EPMoE):
expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ----
-q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
+# q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
......@@ -1072,7 +1071,7 @@ class DeepEPMoE(EPMoE):
# ---- first GEMM ----
torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
-q_a1_all, q_a1_scale,
+hidden_states, hidden_states_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
......
......@@ -623,7 +623,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
topk_ids,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
-use_fp8=False,
+use_fp8=True,
+round_scale=False,
+use_ue8m0=False,
+use_int8=True,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment