Unverified Commit bdde2375 authored by JieXin Liang, committed by GitHub

[perf] experimental enhance fp8 per-tensor quant (#5370)

parent e9fc2ac7
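Editor's orientation note (not part of the commit): the change adds a two-stage Triton kernel for per-tensor FP8 quantization of the MLA weight-absorbed inputs. Stage 1 reduces each (head, token) row to its absolute maximum and combines the results into one global scale with tl.atomic_max; stage 2 rescales by that shared scale, clamps to the FP8 range, and writes a contiguous FP8 tensor. The DeepSeek-V2 MLA call sites switch from input_to_float8 to the new per_tensor_quant_mla_fp8 wrapper, and block_quant_to_tensor_quant / channel_quant_to_tensor_quant dispatch to sgl_scaled_fp8_quant on CUDA. Below is a rough, illustration-only PyTorch sketch of what the two stages compute; the function name is made up for this note and the HIP e4m3fnuz branch is omitted.

import torch

def per_tensor_quant_mla_fp8_reference(x, eps=1e-12, dtype=torch.float8_e4m3fn):
    # Stage 1 equivalent: one global scale from the tensor-wide absolute maximum.
    finfo = torch.finfo(dtype)
    x_s = (x.float().abs().amax().clamp(min=eps) / finfo.max).reshape(1)
    # Stage 2 equivalent: rescale, clamp to the FP8 range, cast, contiguous output.
    x_q = (x.float() / x_s).clamp(-finfo.max, finfo.max).to(dtype).contiguous()
    return x_q, x_s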
@@ -839,3 +839,103 @@ def w8a8_block_fp8_matmul(
    )
    return C


@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
    x_ptr,
    x_s_ptr,
    head_size,
    x_stride_h,
    x_stride_s,
    eps,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1: each program reduces one (head, seq) row to its absolute maximum
    # and atomically folds absmax / fp8_max into the single global scale.
    seq_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    mask = offset < head_size

    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
    tl.atomic_max(x_s_ptr, _absmax / fp8_max)


@triton.jit
def _per_tensor_quant_mla_fp8_stage2(
    x_ptr,
    x_s_ptr,
    x_q_ptr,
    num_seq,
    head_size,
    x_stride_h,
    x_stride_s,
    fp8_min,
    fp8_max,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 2: re-read the global scale produced by stage 1, rescale, clamp to
    # the FP8 range, and write each row into the contiguous output tensor.
    seq_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    mask = offset < head_size

    x_s = tl.load(x_s_ptr)
    x_s_inv = 1.0 / x_s
    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size

    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
    tl.store(x_q_ptr + offset, x_q, mask=mask)


def per_tensor_quant_mla_fp8(
    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes input values to float8 with a single tensor-wise scale,
    specialized for the MLA weight-absorbed case.
    """
    assert x.dim() == 3, "`x` is not a 3d-tensor"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
    if _is_hip:
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0

    x_q = x.new_empty(x.size(), dtype=dtype)
    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

    num_head, num_seq, head_size = x.shape
    BLOCK_SIZE = triton.next_power_of_2(head_size)
    grid = (num_seq, num_head)

    _per_tensor_quant_mla_fp8_stage1[grid](
        x,
        x_s,
        head_size,
        x.stride(0),
        x.stride(1),
        eps,
        fp8_max,
        BLOCK_SIZE,
    )
    _per_tensor_quant_mla_fp8_stage2[grid](
        x,
        x_s,
        x_q,
        num_seq,
        head_size,
        x.stride(0),
        x.stride(1),
        -fp8_max,
        fp8_max,
        BLOCK_SIZE,
    )

    return x_q, x_s
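Editor's note, for orientation only (not part of the diff): a minimal call of the wrapper above on a transposed (num_heads, num_tokens, head_dim) view, the same pattern the MLA path below uses before bmm_fp8. Shapes are arbitrary and a CUDA device is assumed.

x = torch.randn(128, 16, 512, dtype=torch.bfloat16, device="cuda")
x_q, x_s = per_tensor_quant_mla_fp8(x.transpose(0, 1))
assert x_q.shape == (16, 128, 512) and x_q.is_contiguous()
assert x_s.shape == (1,)  # single tensor-wide dequantization scale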
@@ -168,13 +168,13 @@ def input_to_float8(
    """This function quantizes input values to float8 values with tensor-wise quantization."""
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
    fp8_max = finfo.max
    if _is_hip:
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0
    scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
@@ -213,7 +213,11 @@ def block_quant_to_tensor_quant(
        for j in range(n_tiles):
            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

-    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
    return x_q_tensor, scale
@@ -222,7 +226,11 @@ def channel_quant_to_tensor_quant(
    x_s: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    x_dq_channel = x_q_channel.to(torch.float32) * x_s
-    x_q_tensor, scale = input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
    return x_q_tensor, scale
@@ -53,10 +53,10 @@ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    channel_quant_to_tensor_quant,
-    input_to_float8,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
@@ -817,8 +817,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -848,8 +848,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
@@ -895,8 +895,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -991,8 +991,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
@@ -7,10 +7,12 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_fp8,
    per_token_group_quant_fp8,
    static_quant_fp8,
    w8a8_block_fp8_matmul,
)
+from sglang.srt.layers.quantization.fp8_utils import input_to_float8
from sglang.test.test_utils import CustomTestCase

_is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -155,6 +157,61 @@ class TestStaticQuantFP8(CustomTestCase):
                self._static_quant_fp8(*params)


class TestPerTensorQuantMlaFP8(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    LAST_D_EXT = [1024, 0]
    LAST_D = [512]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _per_tensor_quant_mla_fp8(self, num_tokens, d, last_d_ext, last_d, dtype, seed):
        torch.manual_seed(seed)

        x = torch.rand(
            (num_tokens, d // last_d, last_d + last_d_ext),
            dtype=dtype,
        )
        x_sub, _ = x.split([last_d, last_d_ext], dim=-1)

        with torch.inference_mode():
            ref_out, ref_s = input_to_float8(x_sub.transpose(0, 1))
            out, out_s = per_tensor_quant_mla_fp8(x_sub.transpose(0, 1))

        self.assertTrue(out.is_contiguous())
        self.assertTrue(
            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
        )
        self.assertTrue(
            torch.allclose(out_s.to(torch.float32), ref_s.to(torch.float32))
        )

    def test_per_tensor_quant_mla_fp8(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.LAST_D_EXT,
            self.LAST_D,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                d=params[1],
                last_d_ext=params[2],
                last_d=params[3],
                dtype=params[4],
                seed=params[5],
            ):
                self._per_tensor_quant_mla_fp8(*params)

# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.