Unverified Commit 00c7b1ad authored by fzyzcjy, committed by GitHub

Let EP prefill support new DeepGEMM (#7310)

parent 82eccae4
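
Background sketch (not part of the commit, added here only for context): "UE8M0" stores a power-of-two scale as a bare 8-bit exponent with bias 127 and no sign or mantissa bits, which is why the allocation change below can pack four scales into a single int32 word. The helper below is a hedged illustration of that encoding; it assumes the common convention of rounding the scale up to the next power of two and is not taken from sglang or DeepGEMM.

import math

def to_ue8m0_byte(scale: float) -> int:
    # Encode a positive scale as an exponent-only byte: 2 ** (byte - 127) is
    # the smallest power of two that is >= scale (assumed rounding rule).
    return 127 + math.ceil(math.log2(scale))

print(to_ue8m0_byte(1.0), to_ue8m0_byte(0.03125), to_ue8m0_byte(3.0))  # 127 122 129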
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
@@ -1370,10 +1371,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
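
A minimal sketch (not part of the diff) of the shape arithmetic behind the UE8M0 branch above: with 128-wide quantization groups, a row of K values carries K // 128 scales, and packing four 8-bit exponents per int32 shrinks the per-row scale storage from K // 128 float32 entries to ceil_div(K // 128, 4) int32 words. Here ceil_div is assumed to be plain rounded-up integer division, mirroring the helper imported from sglang.srt.utils in the first hunk.

def ceil_div(a: int, b: int) -> int:
    # Assumed to match sglang.srt.utils.ceil_div: rounded-up integer division.
    return (a + b - 1) // b

K = 7168                                # example hidden size, chosen arbitrarily
num_groups = K // 128                   # one scale per 128-element group -> 56
packed_words = ceil_div(num_groups, 4)  # four UE8M0 bytes per int32 word -> 14
print(num_groups, packed_words)         # 56 14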
@@ -1399,6 +1409,7 @@
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
@@ -1407,7 +1418,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1440,15 @@
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
......
@@ -246,7 +246,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
......
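
For reference, a hedged usage sketch of the quantization call exercised in both files above with the new flags. Only the keyword arguments visible in the diff are taken from the source; the import paths, the need for a CUDA device, and the shape and dtype of the returned scale tensor are assumptions that may differ across sglang versions.

import torch

# Import paths below are assumptions; the diff itself does not show them.
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

use_ue8m0 = deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
x = torch.randn(16, 7168, device="cuda", dtype=torch.bfloat16)

# 128-wide groups, with the UE8M0-related flags mirroring the hunks above.
x_fp8, x_scale = sglang_per_token_group_quant_fp8(
    x,
    128,
    column_major_scales=use_ue8m0,
    scale_tma_aligned=use_ue8m0,
    scale_ue8m0=use_ue8m0,
)
print(x_fp8.shape, x_fp8.dtype, x_scale.shape, x_scale.dtype)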