Unverified Commit df5192cf authored by fzyzcjy, committed by GitHub

Enable fast silu-and-mul-and-quant fused kernel (#11806)

parent 78c43d88
......
@@ -39,6 +39,9 @@ if not (_is_npu or _is_hip):
    from sgl_kernel import silu_and_mul
+_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
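The new _MASKED_GEMM_FAST_ACT flag is read once at import time via get_bool_env_var, so the fused path is opted into through the environment rather than a runtime argument. A minimal sketch of how it would typically be toggled, assuming sglang's usual convention that get_bool_env_var treats values like "1" or "true" as enabled:

import os

# Hypothetical usage sketch: opt into the fused silu-and-mul + quant path.
# The flag is read at module import, so it must be set before the deep_gemm
# runner module is imported (e.g. in the launcher's environment).
os.environ["SGLANG_MASKED_GEMM_FAST_ACT"] = "1"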
......
@@ -214,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
        from sglang.srt.layers.moe.ep_moe.kernels import (
            silu_and_mul_masked_post_quant_fwd,
        )
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_8bit,
+        )
        hidden_states = runner_input.hidden_states
        hidden_states_scale = runner_input.hidden_states_scale
......
@@ -258,33 +264,46 @@
        dispose_tensor(hidden_states_scale)
        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float8_e4m3fn,
-        )
        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
+        if _MASKED_GEMM_FAST_ACT:
+            down_input, down_input_scale = sglang_per_token_group_quant_8bit(
+                x=gateup_output,
+                dst_dtype=torch.float8_e4m3fn,
+                group_size=scale_block_size,
+                masked_m=masked_m,
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                fuse_silu_and_mul=True,
+                enable_v2=True,
+            )
+        else:
+            down_input = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float8_e4m3fn,
+            )
+            down_input_scale = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2 // scale_block_size,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            silu_and_mul_masked_post_quant_fwd(
+                gateup_output,
+                down_input,
+                down_input_scale,
+                scale_block_size,
+                masked_m,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
        del gateup_output
        # GroupGemm-1
......
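For readers unfamiliar with the kernels involved, a plain-PyTorch reference for what both branches of the hunk above compute may help: SiLU applied to the gate half of gateup_output, an elementwise multiply with the up half, then per-token-group FP8 (e4m3) quantization with one float32 scale per 128-channel group. This is only an illustrative sketch under assumptions (the helper name is made up, the gate is taken as the first half, and the masked_m handling, column-major / TMA-aligned scale layout, and UE8M0 scale option of the real kernels are omitted):

import torch

def silu_mul_group_quant_ref(gateup_output: torch.Tensor, group_size: int = 128):
    # gateup_output: [..., 2 * d]; assume the first half is the gate projection.
    d = gateup_output.shape[-1] // 2
    gate, up = gateup_output[..., :d], gateup_output[..., d:]
    act = torch.nn.functional.silu(gate.float()) * up.float()

    # One scale per contiguous group of `group_size` channels (d is assumed to
    # divide evenly into groups).
    groups = act.view(*act.shape[:-1], d // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    down_input = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return down_input.view(*act.shape), scale.squeeze(-1)

The fused path requests this combination in a single pass over gateup_output, which is what the fuse_silu_and_mul=True argument in the diff above selects.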
......
@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale(
            x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
            device=device,
            dtype=torch.float32,
-        ).permute(-1, -2)[: x_shape[-2], :]
+        ).transpose(-1, -2)[: x_shape[-2], :]
    else:
        return torch.empty(
            (x_shape[-1] // group_size,) + x_shape[:-1],
......
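The last hunk's one-line change swaps Tensor.permute(-1, -2) for Tensor.transpose(-1, -2) when building the column-major scale buffer. The snippet below (toy shapes, illustration only) shows the behavioural difference that presumably motivates it: permute expects a full permutation of every dimension, so passing just two indices raises on anything other than a 2-D tensor, whereas transpose swaps exactly the named pair for any rank:

import torch

x = torch.empty(4, 8, 16)  # toy 3-D shape, e.g. (experts, tokens, scale groups)

# transpose(-1, -2) swaps only the last two dims and works for any rank.
print(x.transpose(-1, -2).shape)  # torch.Size([4, 16, 8])

# permute requires indices for *all* dims, so two indices fail on 3-D input.
try:
    x.permute(-1, -2)
except RuntimeError as err:
    print("permute(-1, -2) raises:", err)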