Commit 8d7f3017 authored by lizhigong
Browse files

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

低延迟模式支持dispatch int8

See merge request OpenDAS/sglang!41
parents e3c76844 28edf80d
......@@ -999,7 +999,7 @@ class DeepEPMoE(EPMoE):
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, _, _, masked_m, expected_m = dispatch_output
hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
......@@ -1008,8 +1008,7 @@ class DeepEPMoE(EPMoE):
expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
w13_scales = self.w13_weight_scale
......@@ -1021,7 +1020,7 @@ class DeepEPMoE(EPMoE):
# ---- first GEMM ----
torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
q_a1_all, q_a1_scale,
hidden_states, hidden_states_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
......@@ -1049,7 +1048,7 @@ class DeepEPMoE(EPMoE):
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
hidden_states, hidden_states_scale, topk_ids, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# base shapes
......@@ -1057,7 +1056,7 @@ class DeepEPMoE(EPMoE):
expected_m = min(m, expected_m)
# ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
......@@ -1070,7 +1069,7 @@ class DeepEPMoE(EPMoE):
# ---- first GEMM ----
torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
q_a1_all, q_a1_scale,
hidden_states, hidden_states_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
......
......@@ -616,7 +616,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
topk_ids,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=False,
use_fp8=True,
round_scale=False,
use_ue8m0=False,
use_int8=True,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment