Unverified Commit df5192cf authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Enable fast silu-and-mul-and-quant fused kernel (#11806)

parent 78c43d88
...@@ -39,6 +39,9 @@ if not (_is_npu or _is_hip): ...@@ -39,6 +39,9 @@ if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul from sgl_kernel import silu_and_mul
_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")
# TODO(kaixih@nvidia): ideally we should merge this logic into # TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale. # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile @torch.compile
...@@ -214,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore): ...@@ -214,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
from sglang.srt.layers.moe.ep_moe.kernels import ( from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd, silu_and_mul_masked_post_quant_fwd,
) )
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_8bit,
)
hidden_states = runner_input.hidden_states hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale hidden_states_scale = runner_input.hidden_states_scale
...@@ -258,6 +264,20 @@ class DeepGemmRunnerCore(MoeRunnerCore): ...@@ -258,6 +264,20 @@ class DeepGemmRunnerCore(MoeRunnerCore):
dispose_tensor(hidden_states_scale) dispose_tensor(hidden_states_scale)
# Act # Act
scale_block_size = 128
if _MASKED_GEMM_FAST_ACT:
down_input, down_input_scale = sglang_per_token_group_quant_8bit(
x=gateup_output,
dst_dtype=torch.float8_e4m3fn,
group_size=scale_block_size,
masked_m=masked_m,
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
fuse_silu_and_mul=True,
enable_v2=True,
)
else:
down_input = torch.empty( down_input = torch.empty(
( (
gateup_output.shape[0], gateup_output.shape[0],
...@@ -267,7 +287,6 @@ class DeepGemmRunnerCore(MoeRunnerCore): ...@@ -267,7 +287,6 @@ class DeepGemmRunnerCore(MoeRunnerCore):
device=hidden_states_device, device=hidden_states_device,
dtype=torch.float8_e4m3fn, dtype=torch.float8_e4m3fn,
) )
scale_block_size = 128
down_input_scale = torch.empty( down_input_scale = torch.empty(
( (
gateup_output.shape[0], gateup_output.shape[0],
......
...@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale( ...@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale(
x_shape[:-2] + (x_shape[-1] // group_size, aligned_size), x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
device=device, device=device,
dtype=torch.float32, dtype=torch.float32,
).permute(-1, -2)[: x_shape[-2], :] ).transpose(-1, -2)[: x_shape[-2], :]
else: else:
return torch.empty( return torch.empty(
(x_shape[-1] // group_size,) + x_shape[:-1], (x_shape[-1] // group_size,) + x_shape[:-1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment