Unverified Commit 99456bca authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[perf] introduce deep gemm group_gemm_masked as bmm (#5432)

parent d07e797a
......@@ -44,6 +44,7 @@ else:
fp8_min = -fp8_max
_enable_jit_deepgemm = False
_enable_jit_deepgemm_bmm = False
if _is_cuda:
import deep_gemm
from sgl_kernel import (
......@@ -53,10 +54,11 @@ if _is_cuda:
)
sm_version = get_device_sm()
if sm_version == 90 and get_bool_env_var(
"SGL_ENABLE_JIT_DEEPGEMM", default="false"
):
_enable_jit_deepgemm = True
if sm_version == 90:
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
_enable_jit_deepgemm = True
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
_enable_jit_deepgemm_bmm = True
logger = logging.getLogger(__name__)
......@@ -940,6 +942,108 @@ def per_tensor_quant_mla_fp8(
return x_q, x_s_out
@triton.jit
def _per_token_group_quant_mla_deep_gemm_masked_fp8(
y_ptr,
y_q_ptr,
y_s_ptr,
masked_m_ptr,
group_size,
y_stride_b,
y_stride_t,
y_q_stride_b,
y_q_stride_t,
y_s_stride_b,
y_s_stride_g,
eps,
fp8_min,
fp8_max,
NUM_GROUP: tl.constexpr,
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor for deep_gemm grouped_gemm_masked.
This function converts the tensor values into float8 values.
y and y_q: (b, t, k)
y_s: (b, k//group_size, t)
"""
t_id = tl.program_id(0)
b_id = tl.program_id(1)
y_ptr += b_id * y_stride_b + t_id * y_stride_t
y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
y_s_ptr += b_id * y_s_stride_b + t_id
if t_id == 0:
tl.store(masked_m_ptr + b_id, tl.num_programs(0))
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
for gid in range(NUM_GROUP):
y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
tl.float32
)
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
def per_tensor_quant_mla_deep_gemm_masked_fp8(
x: torch.Tensor,
group_size: int = 128,
eps: float = 1e-12,
dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function quantizes input values to float8 values with per-token-group-quantization
for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
b, m, k = x.shape
aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel
num_tiles_k = k // group_size
assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
masked_m = x.new_empty((b,), dtype=torch.int32)
BLOCK_SIZE = triton.next_power_of_2(group_size)
grid = (m, b)
_per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
x,
x_q,
x_s,
masked_m,
group_size,
x.stride(0),
x.stride(1),
x_q.stride(0),
x_q.stride(1),
x_s.stride(0),
x_s.stride(1),
eps,
-fp8_max,
fp8_max,
num_tiles_k,
BLOCK_SIZE,
)
return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
......
......@@ -57,7 +57,11 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm_bmm,
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
channel_quant_to_tensor_quant,
......@@ -82,6 +86,7 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
else:
from vllm._custom_ops import awq_dequantize
......@@ -530,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_vc = None
self.w_scale = None
self.w_scale_k = None
self.w_scale_v = None
self.use_deep_gemm_bmm = False
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
......@@ -684,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module):
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
if self.w_kc.dtype == torch.float8_e4m3fnuz:
if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
)
)
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
masked_m,
expected_m,
)
q_nope_out = q_nope_out[:, :expected_m, :]
elif self.w_kc.dtype == torch.float8_e4m3fnuz:
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
......@@ -716,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fnuz:
if self.use_deep_gemm_bmm:
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
per_tensor_quant_mla_deep_gemm_masked_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
)
)
attn_bmm_output = attn_output.new_empty(
(self.num_local_heads, aligned_m, self.v_head_dim)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
(attn_output_val, attn_output_scale),
(self.w_vc, self.w_scale_v),
attn_bmm_output,
masked_m,
expected_m,
)
attn_bmm_output = attn_bmm_output[:, :expected_m, :]
elif self.w_vc.dtype == torch.float8_e4m3fnuz:
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
......@@ -1439,6 +1482,10 @@ class DeepseekV2ForCausalLM(nn.Module):
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
model_dtype = torch.get_default_dtype()
if w.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
......@@ -1457,10 +1504,20 @@ class DeepseekV2ForCausalLM(nn.Module):
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if (
_is_cuda
and _enable_jit_deepgemm_bmm
and weight_block_size[0] == 128
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
):
block_scale = weight_scale
use_deep_gemm_bmm = True
else:
w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
else:
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale
......@@ -1483,18 +1540,31 @@ class DeepseekV2ForCausalLM(nn.Module):
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if _is_hip:
self_attn.w_scale *= 2.0
if not use_deep_gemm_bmm:
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if _is_hip:
self_attn.w_scale *= 2.0
else:
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
ws_kc, ws_vc = block_scale.unflatten(
0, (-1, (num_tiles_k + num_tiles_n))
).split([num_tiles_k, num_tiles_n], dim=1)
self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
self_attn.w_scale_v = ws_vc.contiguous()
self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
self_attn.w_vc = w_vc.contiguous()
self_attn.use_deep_gemm_bmm = True
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -7,6 +7,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
static_quant_fp8,
......@@ -212,6 +213,62 @@ class TestPerTensorQuantMlaFP8(CustomTestCase):
self._per_tensor_quant_mla_fp8(*params)
class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
B = [128]
NUM_TOKENS = [7, 83, 2048, 1024 * 16]
D = [512, 128]
GROUP_SIZE = [128]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _per_token_group_quant_mla_deep_gemm_masked_fp8(
self, b, num_tokens, d, dtype, group_size, seed
):
torch.manual_seed(seed)
x = torch.rand(b, num_tokens, d, dtype=dtype)
with torch.inference_mode():
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
x, group_size
)
out = out[:, :num_tokens, :]
scale = scale[:, :num_tokens, :]
self.assertTrue(
torch.allclose(
out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20, atol=1e-2
)
)
self.assertTrue(torch.allclose(scale, ref_scale))
def test_per_token_group_quant_mla_deep_gemm_masked_fp8(self):
for params in itertools.product(
self.B,
self.NUM_TOKENS,
self.D,
self.DTYPES,
self.GROUP_SIZE,
self.SEEDS,
):
with self.subTest(
b=params[0],
num_tokens=params[1],
d=params[2],
dtype=params[3],
group_size=params[4],
seed=params[5],
):
self._per_token_group_quant_mla_deep_gemm_masked_fp8(*params)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
......@@ -485,5 +542,115 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
self._w8a8_block_fp8_fused_moe(*params)
# For test
def torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_shape, out_dtype):
"""This function performs bmm with block-wise quantization using native torch."""
B, N, _ = w.shape
_, M, _ = a.shape
out = torch.empty((B, M, N), dtype=out_dtype, device=a.device)
for i in range(B):
out[i] = native_w8a8_block_fp8_matmul(
a[i], w[i], a_s[i], w_s[i], block_shape, output_dtype=out_dtype
)
return out
class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
DTYPES = [torch.bfloat16]
M = [1, 33, 64, 222, 8192]
N = [128, 512]
K = [128, 512]
BATCH = [128]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
try:
import deep_gemm
except ImportError:
raise unittest.SkipTest("DeepGEMM is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_batched_deep_gemm(self, M, N, K, B, block_size, dtype, seed):
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = torch.randn((B, M, K), dtype=torch.float32) / 10
a = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w_fp32 = (torch.rand((B, N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
w = w_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w = (N + block_n - 1) // block_n
k_tiles_w = (K + block_k - 1) // block_k
w_s = (
torch.rand((B, n_tiles_w, k_tiles_w), dtype=torch.float32)
* factor_for_scale
)
a_s = torch.rand((B, M, k_tiles_w), dtype=torch.float32) * factor_for_scale
ae = a.new_empty(B, (M + 255) // 256 * 256, K)
ae_s = a_s.new_empty(B, (M + 255) // 256 * 256, k_tiles_w)
oe = torch.empty((B, (M + 255) // 256 * 256, N), dtype=dtype)
ae[:, :M, :] = a
ae_s[:, :M, :] = a_s
masked_m = torch.full((B,), M, dtype=torch.int)
expected_m = M
lhs = (
ae,
ae_s,
)
rhs = (
w,
w_s,
)
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
out = oe[:, :M, :]
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
< 0.0001
)
def test_w8a8_block_fp8_batched_deep_gemm(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.BATCH,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
B=params[3],
block_size=params[4],
dtype=params[5],
seed=params[6],
):
self._w8a8_block_fp8_batched_deep_gemm(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment