Unverified Commit 00c7b1ad authored by fzyzcjy, committed by GitHub

Let EP prefill support new DeepGEMM (#7310)

parent 82eccae4
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
@@ -1370,10 +1371,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
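Reviewer note: the key change in this hunk is the scale-buffer layout. With DEEPGEMM_SCALE_UE8M0 enabled, the per-128-channel scales become 8-bit power-of-two exponents packed four per int32, so the buffer goes from `(all_tokens, K // 128)` float32 to `(ceil_div(K // 128, 4), all_tokens)` int32, transposed back to token-major. A standalone sketch of the shape arithmetic (the sizes are made-up examples, not values from this PR):

```python
import torch

def ceil_div(a: int, b: int) -> int:
    # Same helper as sglang.srt.utils.ceil_div.
    return (a + b - 1) // b

all_tokens, K = 4096, 7168          # hypothetical example sizes
num_groups = K // 128               # one scale per 128-element group

# Old path: one float32 scale per (token, group).
fp32_scales = torch.empty((all_tokens, num_groups), dtype=torch.float32)

# New UE8M0 path: four 8-bit exponent scales packed into each int32,
# allocated group-major and viewed token-major via transpose.
packed = torch.zeros((ceil_div(num_groups, 4), all_tokens), dtype=torch.int)
ue8m0_scales = packed.transpose(0, 1)   # (all_tokens, ceil_div(num_groups, 4))
```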
@@ -1399,6 +1409,7 @@ class DeepEPMoE(EPMoE):
                 input_tensor[1],
                 m_indices,
                 output_index,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
             )
             dispose_tensor(hidden_states_fp8)
@@ -1407,7 +1418,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
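The packed UE8M0 scales already come out in the layout the new DeepGEMM consumes, so the TMA re-alignment pass is now gated off for that path. For intuition only, here is a rough sketch of what an alignment pass like tma_align_input_scale does, assuming the requirement is padding the token dimension to a multiple of 4; the real helper in sglang is a fused kernel and may differ in detail:

```python
import torch

def tma_align_sketch(scales: torch.Tensor, alignment: int = 4) -> torch.Tensor:
    # Hypothetical illustration: pad the token dimension up to a multiple
    # of `alignment` so the column-major stride meets TMA requirements,
    # then hand back a view over the original rows.
    m, k = scales.shape
    padded_m = (m + alignment - 1) // alignment * alignment
    buf = scales.new_zeros((k, padded_m))
    buf[:, :m] = scales.t()
    return buf.t()[:m]
```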
@@ -1428,10 +1440,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
...
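The dispatcher change below threads the same three flags through the quantization performed before FP8 communication. For context, per-token-group quantization splits each row into groups of `group_size` channels and keeps one scale per group; a minimal pure-PyTorch sketch of the default float-scale semantics (not the fused sglang Triton kernel, and with the layout flags omitted):

```python
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    # Reference semantics only; sglang_per_token_group_quant_fp8 is a
    # fused kernel with extra layout options. Assumes k % group_size == 0.
    finfo = torch.finfo(torch.float8_e4m3fn)
    m, k = x.shape
    g = x.view(m, k // group_size, group_size)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / finfo.max                      # one float scale per group
    q = (g / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.view(m, k), scale.squeeze(-1)
```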
@@ -246,7 +246,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
...
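For completeness: UE8M0 is an 8-bit unsigned exponent with no mantissa, so each scale is a power of two stored as a biased exponent byte, which is what makes the four-per-int32 packing in the first hunk possible. An illustrative codec, assuming the usual bias of 127 and round-up-to-the-next-power-of-two (so quantized values never overflow); the actual packing order used by DeepGEMM is not specified in this diff:

```python
import math

def to_ue8m0(scale: float, bias: int = 127) -> int:
    # Round the scale up to a power of two and store the biased exponent.
    e = math.ceil(math.log2(scale))
    return max(0, min(255, e + bias))

def from_ue8m0(byte: int, bias: int = 127) -> float:
    return 2.0 ** (byte - bias)

def pack4(bytes4: list[int]) -> int:
    # Pack four UE8M0 bytes into one int32 (little-endian order assumed).
    return bytes4[0] | (bytes4[1] << 8) | (bytes4[2] << 16) | (bytes4[3] << 24)
```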