Unverified Commit df5192cf authored by fzyzcjy, committed by GitHub

Enable fast silu-and-mul-and-quant fused kernel (#11806)

parent 78c43d88
@@ -39,6 +39,9 @@ if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
+_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")
+
+
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
 @torch.compile
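The fast path is gated by the new `SGLANG_MASKED_GEMM_FAST_ACT` environment variable, evaluated once at module import. A minimal sketch of this kind of boolean env-var gate, using a simplified stand-in for sglang's `get_bool_env_var` helper (the real utility may accept a different set of truthy spellings):

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Simplified stand-in: treat common truthy spellings as "enabled".
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")


# Read once at import time, mirroring the flag added in this commit.
_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")

if __name__ == "__main__":
    # Enable the fused activation+quant path by exporting the variable before
    # launching, e.g.  SGLANG_MASKED_GEMM_FAST_ACT=1 python -m sglang.launch_server ...
    print("masked-gemm fast activation enabled:", _MASKED_GEMM_FAST_ACT)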
@@ -214,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             silu_and_mul_masked_post_quant_fwd,
         )
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_8bit,
+        )
 
         hidden_states = runner_input.hidden_states
         hidden_states_scale = runner_input.hidden_states_scale
@@ -258,33 +264,46 @@ class DeepGemmRunnerCore(MoeRunnerCore):
             dispose_tensor(hidden_states_scale)
 
         # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float8_e4m3fn,
-        )
         scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
+        if _MASKED_GEMM_FAST_ACT:
+            down_input, down_input_scale = sglang_per_token_group_quant_8bit(
+                x=gateup_output,
+                dst_dtype=torch.float8_e4m3fn,
+                group_size=scale_block_size,
+                masked_m=masked_m,
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                fuse_silu_and_mul=True,
+                enable_v2=True,
+            )
+        else:
+            down_input = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float8_e4m3fn,
+            )
+            down_input_scale = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2 // scale_block_size,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            silu_and_mul_masked_post_quant_fwd(
+                gateup_output,
+                down_input,
+                down_input_scale,
+                scale_block_size,
+                masked_m,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         del gateup_output
 
         # GroupGemm-1
...
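For reference, the fused `sglang_per_token_group_quant_8bit(..., fuse_silu_and_mul=True)` call and the old allocate-then-`silu_and_mul_masked_post_quant_fwd` path compute the same thing: SiLU(gate) * up over each expert's `masked_m` valid rows, followed by per-128-element-group quantization to float8_e4m3fn with float32 scales. Below is a minimal, unfused PyTorch sketch of that math; the eager per-expert loop and the helper name are illustrative only, and it ignores the column-major / TMA-aligned / UE8M0 scale layouts the real kernels produce.

import torch


def silu_and_mul_then_quant_ref(
    gateup_output: torch.Tensor,  # [num_experts, max_m, 2 * N], bf16/fp16
    masked_m: torch.Tensor,       # [num_experts], valid rows per expert
    group_size: int = 128,
):
    """Unfused reference: act = SiLU(gate) * up, then per-group FP8 quant."""
    E, M, two_n = gateup_output.shape
    N = two_n // 2
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    down_input = torch.zeros(
        E, M, N, dtype=torch.float8_e4m3fn, device=gateup_output.device
    )
    down_input_scale = torch.zeros(
        E, M, N // group_size, dtype=torch.float32, device=gateup_output.device
    )
    for e in range(E):
        rows = int(masked_m[e])
        # First half of the last dim is the gate, second half is the up projection.
        gate, up = gateup_output[e, :rows].float().chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up           # "silu and mul"
        groups = act.view(rows, N // group_size, group_size)
        amax = groups.abs().amax(dim=-1).clamp(min=1e-10)
        scale = amax / fp8_max                               # one scale per group
        q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
        down_input[e, :rows] = q.view(rows, N).to(torch.float8_e4m3fn)
        down_input_scale[e, :rows] = scale
    return down_input, down_input_scale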
@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale(
             x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
             device=device,
             dtype=torch.float32,
-        ).permute(-1, -2)[: x_shape[-2], :]
+        ).transpose(-1, -2)[: x_shape[-2], :]
     else:
         return torch.empty(
             (x_shape[-1] // group_size,) + x_shape[:-1],
...
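The second file's one-line change swaps `.permute(-1, -2)` for `.transpose(-1, -2)` on the column-major scale buffer. `Tensor.permute` requires a full ordering of every dimension, so passing only two indices fails once the buffer has an extra leading dimension (presumably what the masked, per-expert layout produces here), while `.transpose(-1, -2)` swaps the last two dimensions for any rank and is equivalent in the 2-D case. A small demonstration with illustrative shapes:

import torch

scale_2d = torch.empty(16, 8)
print(scale_2d.permute(-1, -2).shape)    # torch.Size([8, 16]) -- OK for 2-D
print(scale_2d.transpose(-1, -2).shape)  # torch.Size([8, 16]) -- same view

scale_3d = torch.empty(4, 16, 8)         # extra leading (expert/batch) dimension
try:
    scale_3d.permute(-1, -2)             # permute needs all 3 dims listed
except RuntimeError as err:
    print("permute failed:", err)
print(scale_3d.transpose(-1, -2).shape)  # torch.Size([4, 8, 16]) -- still works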