Unverified Commit 76196b3c authored by Ho-Ren (Jack) Chuang, committed by GitHub

feat: Add FP4 (E2M1) KV Cache Support with Quantization Utilities for MLA (#10078)


Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>
parent 95191ebd
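For context on the format this commit adds: FP4 E2M1 encodes each value with 1 sign, 2 exponent, and 1 mantissa bit, so the representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and every 16-element block shares one power-of-two scale stored as a biased uint8 exponent. A minimal plain-Python sketch of that block scaling (illustrative only; the vectorized torch utilities live in the new kvfp4_tensor.py below):

```python
import math

E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # representable E2M1 magnitudes
E2M1_MAX = 6.0

def quantize_block(block):
    """Quantize one 16-element block: shared power-of-two scale + values snapped to E2M1."""
    amax = max(abs(x) for x in block)
    scale_exp = math.ceil(math.log2(max(amax / E2M1_MAX, 1e-10)))  # shared block exponent
    scale = 2.0 ** scale_exp
    deq = [
        math.copysign(min(E2M1_GRID, key=lambda g: abs(abs(x) / scale - g)), x) * scale
        for x in block
    ]
    return scale_exp + 127, deq  # biased uint8 exponent, dequantized approximation

biased_exp, approx = quantize_block([0.05 * i for i in range(16)])
```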
......@@ -23,7 +23,7 @@ from sglang.srt.layers.attention.utils import (
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_flashinfer_available
from sglang.srt.utils import is_cuda, is_flashinfer_available, is_float4_e2m1fn_x2
from sglang.srt.utils.common import cached_triton_kernel
if is_flashinfer_available():
......@@ -361,19 +361,35 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
)
num_tokens_per_bs = max_num_tokens // max_bs
# Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
self.padded_q_buffer = torch.zeros(
(max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
dtype=self.data_type,
device=self.device,
)
if is_float4_e2m1fn_x2(self.data_type):
# Buffer for padded query, packed FP4 (two values per uint8):
# (max_bs, num_tokens_per_bs // 2, num_q_heads, kv_cache_dim)
self.store_dtype = torch.uint8
self.padded_q_buffer = torch.zeros(
(max_bs, num_tokens_per_bs // 2, self.num_q_heads, self.kv_cache_dim),
dtype=self.store_dtype,
device=self.device,
)
# Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
self.unpad_output_buffer = torch.zeros(
(max_num_tokens, self.num_q_heads, 512),
dtype=self.data_type,
device=self.device,
)
# Buffer for unpadded output, packed FP4 (two values per uint8):
# (max_num_tokens // 2, num_q_heads, v_head_dim)
self.unpad_output_buffer = torch.zeros(
(max_num_tokens // 2, self.num_q_heads, 512),
dtype=self.store_dtype,
device=self.device,
)
else:
# Buffer for padded query: (max_bs, max_draft_tokens, num_q_heads, v_head_dim)
self.padded_q_buffer = torch.zeros(
(max_bs, num_tokens_per_bs, self.num_q_heads, self.kv_cache_dim),
dtype=self.data_type,
device=self.device,
)
# Buffer for unpadded output: (max_num_tokens, num_q_heads, v_head_dim)
self.unpad_output_buffer = torch.zeros(
(max_num_tokens, self.num_q_heads, 512),
dtype=self.data_type,
device=self.device,
)
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
......
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
E2M1_MAX = 6.0
# Put constants directly on CUDA if available
_device = "cuda" if torch.cuda.is_available() else "cpu"
E2M1_VALUES = torch.tensor(
[0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32, device=_device
)
E2M1_BOUNDS = torch.tensor(
[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32, device=_device
)
class KVFP4QuantizeUtil:
"""Utility class for MXFP4 quantization and dequantization operations."""
@staticmethod
@torch.compile
def batched_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize tensor to KVFP4 format
Args:
tensor: Input tensor of shape [B, M, N]
Returns:
quant_tensor: Quantized tensor of shape [B, M, N/2]
scale_factors: Scale factors of shape [B, M*N/16]
"""
b, m, n = tensor.shape
# Reshape to [B, M*N/16, 16] for block-wise quantization
reshaped = tensor.view(b, m * n // 16, 16)
# Compute scale factors per block
block_max = reshaped.abs().max(dim=-1, keepdim=True).values
scale_exp = torch.ceil(torch.log2(torch.clamp(block_max / E2M1_MAX, min=1e-10)))
scale_factors = (scale_exp + 127).squeeze(-1).to(torch.uint8)
# Apply scaling
scaled = reshaped / torch.exp2(scale_exp)
# Quantize to FP4
sign_bits = (scaled < 0).to(torch.uint8) << 3
abs_vals = scaled.abs()
# Pure tensor version (CUDA Graph safe)
magnitude_bits = torch.sum(abs_vals.unsqueeze(-1) >= E2M1_BOUNDS, dim=-1)
# Combine sign and magnitude
fp4_vals = sign_bits + magnitude_bits.to(torch.uint8)
# Pack two FP4 values into one uint8
fp4_reshaped = fp4_vals.view(b, m, n)
packed = (fp4_reshaped[..., 1::2] << 4) + fp4_reshaped[..., 0::2]
return packed, scale_factors
@staticmethod
@torch.compile
def batched_dequantize(
quant_tensor: torch.Tensor,
scale_factors: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""
Dequantize KVFP4 tensor
Args:
quant_tensor: Quantized tensor of shape [B, M, N/2]
scale_factors: Scale factors of shape [B, M*N/16]
dtype: Target dtype for output
Returns:
Dequantized tensor of shape [B, M, N]
"""
b, m, n_half = quant_tensor.shape
n = n_half * 2
# Unpack two FP4 values from each uint8 using bit operations
fp4_vals = torch.empty(b, m, n, dtype=torch.uint8, device=quant_tensor.device)
fp4_vals[..., 0::2] = quant_tensor & 0x0F
fp4_vals[..., 1::2] = (quant_tensor >> 4) & 0x0F
# Extract sign and magnitude
sign_mask = (fp4_vals & 0x08) != 0
magnitude_idx = fp4_vals & 0x07
# Convert to float values
float_vals = E2M1_VALUES[magnitude_idx.long()]
float_vals = torch.where(sign_mask, -float_vals, float_vals)
# Reshape for block-wise scaling
reshaped = float_vals.view(b, m * n // 16, 16)
# Apply scale factors
scale_exp = scale_factors.float() - 127
scaled = reshaped * torch.exp2(scale_exp.unsqueeze(-1))
return scaled.view(b, m, n).to(dtype)
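A quick round-trip sketch for the utilities above (a hedged example, not part of the commit; it assumes a CUDA device since the E2M1 lookup tables are placed on CUDA when available, and uses the DeepSeek MLA width 576 = 512 nope + 64 rope):

```python
import torch

from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil

# [B, M, N] with N a multiple of 16 (one scale per 16-element block)
x = torch.randn(4, 1, 576, dtype=torch.bfloat16, device="cuda")

packed, scales = KVFP4QuantizeUtil.batched_quantize(x)  # packed: uint8 [4, 1, 288]; scales: uint8 [4, 36]
x_hat = KVFP4QuantizeUtil.batched_dequantize(packed, scales, dtype=torch.bfloat16)

print((x - x_hat).abs().mean())  # mean quantization error of the round trip
```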
......@@ -43,7 +43,18 @@ import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
from sglang.srt.mem_cache.utils import (
get_mla_kv_buffer_triton,
set_mla_kv_buffer_triton,
set_mla_kv_scale_buffer_triton,
)
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_float4_e2m1fn_x2,
is_npu,
next_power_of_2,
)
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
......@@ -1261,131 +1272,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
)
@triton.jit
def set_mla_kv_buffer_kernel(
kv_buffer_ptr,
cache_k_nope_ptr,
cache_k_rope_ptr,
loc_ptr,
buffer_stride: tl.constexpr,
nope_stride: tl.constexpr,
rope_stride: tl.constexpr,
nope_dim: tl.constexpr,
rope_dim: tl.constexpr,
BLOCK: tl.constexpr,
):
pid_loc = tl.program_id(0)
pid_blk = tl.program_id(1)
base = pid_blk * BLOCK
offs = base + tl.arange(0, BLOCK)
total_dim = nope_dim + rope_dim
mask = offs < total_dim
loc = tl.load(loc_ptr + pid_loc)
dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
if base + BLOCK <= nope_dim:
src = tl.load(
cache_k_nope_ptr + pid_loc * nope_stride + offs,
mask=mask,
)
else:
offs_rope = offs - nope_dim
src = tl.load(
cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
mask=mask,
)
tl.store(dst_ptr, src, mask=mask)
def set_mla_kv_buffer_triton(
kv_buffer: torch.Tensor,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
nope_dim = cache_k_nope.shape[-1]
rope_dim = cache_k_rope.shape[-1]
total_dim = nope_dim + rope_dim
BLOCK = 128
n_loc = loc.numel()
grid = (n_loc, triton.cdiv(total_dim, BLOCK))
set_mla_kv_buffer_kernel[grid](
kv_buffer,
cache_k_nope,
cache_k_rope,
loc,
kv_buffer.stride(0),
cache_k_nope.stride(0),
cache_k_rope.stride(0),
nope_dim,
rope_dim,
BLOCK=BLOCK,
)
@triton.jit
def get_mla_kv_buffer_kernel(
kv_buffer_ptr,
cache_k_nope_ptr,
cache_k_rope_ptr,
loc_ptr,
buffer_stride: tl.constexpr,
nope_stride: tl.constexpr,
rope_stride: tl.constexpr,
nope_dim: tl.constexpr,
rope_dim: tl.constexpr,
):
pid_loc = tl.program_id(0)
loc = tl.load(loc_ptr + pid_loc)
loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
nope_offs = tl.arange(0, nope_dim)
nope_src_ptr = loc_src_ptr + nope_offs
nope_src = tl.load(nope_src_ptr)
tl.store(
cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
nope_src,
)
rope_offs = tl.arange(0, rope_dim)
rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
rope_src = tl.load(rope_src_ptr)
tl.store(
cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
rope_src,
)
def get_mla_kv_buffer_triton(
kv_buffer: torch.Tensor,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
# The source data type will be implicitly converted to the target data type.
nope_dim = cache_k_nope.shape[-1] # 512
rope_dim = cache_k_rope.shape[-1] # 64
n_loc = loc.numel()
grid = (n_loc,)
get_mla_kv_buffer_kernel[grid](
kv_buffer,
cache_k_nope,
cache_k_rope,
loc,
kv_buffer.stride(0),
cache_k_nope.stride(0),
cache_k_rope.stride(0),
nope_dim,
rope_dim,
)
class MLATokenToKVPool(KVCache):
def __init__(
self,
......@@ -1430,15 +1316,41 @@ class MLATokenToKVPool(KVCache):
if self.custom_mem_pool
else nullcontext()
):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, self.kv_cache_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
if is_float4_e2m1fn_x2(self.dtype):
m = size + page_size
n = 1 # head_num
k = self.kv_cache_dim # head_dim
scale_block_size = 16
self.store_dtype = torch.uint8
self.kv_buffer = [
torch.zeros(
(m, n, k // 2),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
self.kv_scale_buffer = [
torch.zeros(
(m, k // scale_block_size),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
else:
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, self.kv_cache_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.kv_buffer],
......@@ -1471,7 +1383,23 @@ class MLATokenToKVPool(KVCache):
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
if is_float4_e2m1fn_x2(self.dtype):
cache_k_nope_fp4 = self.kv_buffer[layer_id - self.start_layer].view(
torch.uint8
)
cache_k_nope_fp4_sf = self.kv_scale_buffer[layer_id - self.start_layer]
from sglang.srt.layers.quantization.kvfp4_tensor import (
KVFP4QuantizeUtil,
)
cache_k_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
cache_k_nope_fp4, cache_k_nope_fp4_sf
)
return cache_k_nope_fp4_dequant
else:
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
return self.kv_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
......@@ -1497,11 +1425,29 @@ class MLATokenToKVPool(KVCache):
layer_id = layer.layer_id
assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if is_float4_e2m1fn_x2(self.dtype):
from sglang.srt.layers.quantization.kvfp4_tensor import (
KVFP4QuantizeUtil,
)
cache_k_fp4, cache_k_fp4_sf = KVFP4QuantizeUtil.batched_quantize(
cache_k
)
else:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
self.store_dtype
)
if is_float4_e2m1fn_x2(self.dtype):
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k_fp4.view(
self.store_dtype
)
self.kv_scale_buffer[layer_id - self.start_layer][loc] = (
cache_k_fp4_sf.view(self.store_dtype)
)
else:
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
self.store_dtype
)
else:
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
......@@ -1523,18 +1469,44 @@ class MLATokenToKVPool(KVCache):
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
else:
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if is_float4_e2m1fn_x2(self.dtype):
from sglang.srt.layers.quantization.kvfp4_tensor import (
KVFP4QuantizeUtil,
)
cache_k_nope_fp4, cache_k_nope_fp4_sf = (
KVFP4QuantizeUtil.batched_quantize(cache_k_nope)
)
cache_k_rope_fp4, cache_k_rope_fp4_sf = (
KVFP4QuantizeUtil.batched_quantize(cache_k_rope)
)
else:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope,
cache_k_rope,
)
if is_float4_e2m1fn_x2(self.dtype):
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope_fp4,
cache_k_rope_fp4,
)
set_mla_kv_scale_buffer_triton(
self.kv_scale_buffer[layer_id - self.start_layer],
loc,
cache_k_nope_fp4_sf,
cache_k_rope_fp4_sf,
)
else:
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope,
cache_k_rope,
)
def get_mla_kv_buffer(
self,
......
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities."""
import torch
import triton
import triton.language as tl
@triton.jit
def set_mla_kv_buffer_kernel(
kv_buffer_ptr,
cache_k_nope_ptr,
cache_k_rope_ptr,
loc_ptr,
buffer_stride: tl.constexpr,
nope_stride: tl.constexpr,
rope_stride: tl.constexpr,
nope_dim: tl.constexpr,
rope_dim: tl.constexpr,
BLOCK: tl.constexpr,
):
pid_loc = tl.program_id(0)
pid_blk = tl.program_id(1)
base = pid_blk * BLOCK
offs = base + tl.arange(0, BLOCK)
total_dim = nope_dim + rope_dim
mask = offs < total_dim
loc = tl.load(loc_ptr + pid_loc)
dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
if base + BLOCK <= nope_dim:
src = tl.load(
cache_k_nope_ptr + pid_loc * nope_stride + offs,
mask=mask,
)
else:
offs_rope = offs - nope_dim
src = tl.load(
cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
mask=mask,
)
tl.store(dst_ptr, src, mask=mask)
def set_mla_kv_buffer_triton(
kv_buffer: torch.Tensor,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
nope_dim = cache_k_nope.shape[-1]
rope_dim = cache_k_rope.shape[-1]
total_dim = nope_dim + rope_dim
BLOCK = 128
n_loc = loc.numel()
grid = (n_loc, triton.cdiv(total_dim, BLOCK))
set_mla_kv_buffer_kernel[grid](
kv_buffer,
cache_k_nope,
cache_k_rope,
loc,
kv_buffer.stride(0),
cache_k_nope.stride(0),
cache_k_rope.stride(0),
nope_dim,
rope_dim,
BLOCK=BLOCK,
)
@triton.jit
def set_mla_kv_scale_buffer_kernel(
kv_buffer_ptr,
cache_k_nope_ptr,
cache_k_rope_ptr,
loc_ptr,
buffer_stride: tl.constexpr,
nope_stride: tl.constexpr,
rope_stride: tl.constexpr,
nope_dim: tl.constexpr,
rope_dim: tl.constexpr,
BLOCK: tl.constexpr,
):
pid_loc = tl.program_id(0)
pid_blk = tl.program_id(1)
base = pid_blk * BLOCK
offs = base + tl.arange(0, BLOCK)
total_dim = nope_dim + rope_dim
mask = offs < total_dim # Make sure we don't cross the boundary
loc = tl.load(loc_ptr + pid_loc)
dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
# Decide whether each offset reads from 'nope' or 'rope'
is_nope = offs < nope_dim
src_nope = tl.load(
cache_k_nope_ptr + pid_loc * nope_stride + offs, mask=mask & is_nope, other=0.0
)
src_rope = tl.load(
cache_k_rope_ptr + pid_loc * rope_stride + (offs - nope_dim),
mask=mask & ~is_nope,
other=0.0,
)
# Combine nope + rope
src = src_nope + src_rope
tl.store(dst_ptr, src, mask=mask)
def set_mla_kv_scale_buffer_triton(
kv_buffer: torch.Tensor,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
nope_dim = cache_k_nope.shape[-1]
rope_dim = cache_k_rope.shape[-1]
total_dim = nope_dim + rope_dim
BLOCK = 128 # Same block size as set_mla_kv_buffer_triton; works for the smaller total_dim here as well.
n_loc = loc.numel()
grid = (n_loc, triton.cdiv(total_dim, BLOCK))
set_mla_kv_scale_buffer_kernel[grid](
kv_buffer,
cache_k_nope,
cache_k_rope,
loc,
kv_buffer.stride(0),
cache_k_nope.stride(0),
cache_k_rope.stride(0),
nope_dim,
rope_dim,
BLOCK=BLOCK,
)
@triton.jit
def get_mla_kv_buffer_kernel(
kv_buffer_ptr,
cache_k_nope_ptr,
cache_k_rope_ptr,
loc_ptr,
buffer_stride: tl.constexpr,
nope_stride: tl.constexpr,
rope_stride: tl.constexpr,
nope_dim: tl.constexpr,
rope_dim: tl.constexpr,
):
pid_loc = tl.program_id(0)
loc = tl.load(loc_ptr + pid_loc)
loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
nope_offs = tl.arange(0, nope_dim)
nope_src_ptr = loc_src_ptr + nope_offs
nope_src = tl.load(nope_src_ptr)
tl.store(
cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
nope_src,
)
rope_offs = tl.arange(0, rope_dim)
rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
rope_src = tl.load(rope_src_ptr)
tl.store(
cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
rope_src,
)
def get_mla_kv_buffer_triton(
kv_buffer: torch.Tensor,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
# The source data type will be implicitly converted to the target data type.
nope_dim = cache_k_nope.shape[-1] # 512
rope_dim = cache_k_rope.shape[-1] # 64
n_loc = loc.numel()
grid = (n_loc,)
get_mla_kv_buffer_kernel[grid](
kv_buffer,
cache_k_nope,
cache_k_rope,
loc,
kv_buffer.stride(0),
cache_k_nope.stride(0),
cache_k_rope.stride(0),
nope_dim,
rope_dim,
)
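A small usage sketch for the copy helpers above (hedged example, not part of the commit; nope_dim and rope_dim follow the 512/64 MLA layout, which also satisfies the power-of-two requirement of tl.arange in the get kernel):

```python
import torch

from sglang.srt.mem_cache.utils import (
    get_mla_kv_buffer_triton,
    set_mla_kv_buffer_triton,
)

nope_dim, rope_dim = 512, 64  # powers of two, as required by the get kernel's tl.arange
pool_size, n_tokens = 1024, 8

kv_buffer = torch.zeros(pool_size, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda")
loc = torch.randperm(pool_size, device="cuda")[:n_tokens]  # destination slots in the pool

k_nope = torch.randn(n_tokens, nope_dim, dtype=torch.bfloat16, device="cuda")
k_rope = torch.randn(n_tokens, rope_dim, dtype=torch.bfloat16, device="cuda")
set_mla_kv_buffer_triton(kv_buffer, loc, k_nope, k_rope)  # scatter rows into the pool

out_nope = torch.empty_like(k_nope)
out_rope = torch.empty_like(k_rope)
get_mla_kv_buffer_triton(kv_buffer, loc, out_nope, out_rope)  # gather them back
assert torch.equal(out_nope, k_nope) and torch.equal(out_rope, k_rope)
```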
......@@ -138,6 +138,8 @@ from sglang.srt.utils import (
get_bool_env_var,
get_cpu_ids_by_node,
init_custom_process_group,
is_cuda,
is_float4_e2m1fn_x2,
is_hip,
is_npu,
log_info_on_rank0,
......@@ -195,6 +197,7 @@ def add_chunked_prefix_cache_attention_backend(backend_name):
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
......@@ -1273,6 +1276,21 @@ class ModelRunner:
* num_layers
* torch._utils._element_size(self.kv_cache_dtype)
)
if is_float4_e2m1fn_x2(self.kv_cache_dtype):
# kv_scale_buffer
scale_block_size = 16
cell_size = (cell_size // 2) + (
(
(
self.model_config.kv_lora_rank
+ self.model_config.qk_rope_head_dim
)
// scale_block_size
)
* num_layers
* torch._utils._element_size(self.kv_cache_dtype)
)
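# Worked example (illustrative, not part of the commit; assumes
# element_size(float4_e2m1fn_x2) == 1, i.e. two packed FP4 values per byte):
# with kv_lora_rank=512 and qk_rope_head_dim=64 there are 576 elements per
# token per layer, so the packed payload is 288 bytes plus 576 / 16 = 36
# scale bytes, ~324 bytes/token/layer versus 1152 bytes for a bf16 KV cache.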
# Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
if is_deepseek_nsa(self.model_config.hf_config):
index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
......@@ -1509,6 +1527,15 @@ class ModelRunner:
self.kv_cache_dtype = torch.float8_e4m3fn
elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
self.kv_cache_dtype = torch.bfloat16
elif self.server_args.kv_cache_dtype == "fp4_e2m1":
if hasattr(torch, "float4_e2m1fn_x2"):
self.kv_cache_dtype = torch.float4_e2m1fn_x2
logger.warning("FP4 (E2M1) KV cache may lead to an accuracy drop!")
else:
logger.warning(
"--kv-cache-dtype falls back to 'auto' because this torch version does not support torch.float4_e2m1fn_x2."
)
self.kv_cache_dtype = self.dtype
else:
raise ValueError(
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
......
......@@ -1204,9 +1204,9 @@ class ServerArgs:
)
self.page_size = 64
if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
if self.kv_cache_dtype not in ["fp8_e4m3", "fp4_e2m1", "auto"]:
raise ValueError(
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3, fp4_e2m1, or auto."
)
if (
......@@ -1987,8 +1987,8 @@ class ServerArgs:
"--kv-cache-dtype",
type=str,
default=ServerArgs.kv_cache_dtype,
choices=["auto", "fp8_e5m2", "fp8_e4m3", "bf16", "bfloat16"],
help='Data type for kv cache storage. "auto" will use model data type. "bf16" or "bfloat16" for BF16 KV cache. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
choices=["auto", "fp8_e5m2", "fp8_e4m3", "bf16", "bfloat16", "fp4_e2m1"],
help='Data type for kv cache storage. "auto" will use model data type. "bf16" or "bfloat16" for BF16 KV cache. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+. "fp4_e2m1" (MXFP4 only) is supported for CUDA 12.8+ and PyTorch 2.8.0+.',
)
parser.add_argument(
"--enable-fp32-lm-head",
......
......@@ -152,6 +152,12 @@ def is_cpu() -> bool:
return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
def is_float4_e2m1fn_x2(dtype) -> bool:
"""Check if dtype is float4_e2m1fn_x2 and CUDA is available."""
target_dtype = getattr(torch, "float4_e2m1fn_x2", None)
return is_cuda() and dtype == target_dtype
def get_cuda_version():
if torch.version.cuda:
return tuple(map(int, torch.version.cuda.split(".")))
......
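A tiny usage sketch for this guard (hedged example, not part of the commit): it returns True only when CUDA is in use and the installed PyTorch exposes the packed FP4 dtype (2.8.0+).

```python
import torch

from sglang.srt.utils import is_float4_e2m1fn_x2

print(is_float4_e2m1fn_x2(torch.bfloat16))  # False: not the packed FP4 dtype
fp4 = getattr(torch, "float4_e2m1fn_x2", None)  # None on older torch builds
if fp4 is not None:
    # True on CUDA; False on CPU-only or non-CUDA builds
    print(is_float4_e2m1fn_x2(fp4))
```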
#!/usr/bin/env python3
import time
import numpy as np
import pytest
import torch
from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil
def calculate_accuracy_metrics(
original: torch.Tensor, reconstructed: torch.Tensor
) -> dict[str, float]:
"""Calculate accuracy metrics between original and reconstructed tensors."""
mse = torch.mean((original - reconstructed) ** 2).item()
mae = torch.mean(torch.abs(original - reconstructed)).item()
# PSNR calculation
max_val = torch.max(torch.abs(original)).item()
psnr = 20 * np.log10(max_val / np.sqrt(mse)) if mse > 0 else float("inf")
# Relative error
rel_error = torch.mean(
torch.abs(original - reconstructed) / (torch.abs(original) + 1e-8)
).item()
return {"MSE": mse, "MAE": mae, "PSNR": psnr, "Relative Error": rel_error}
def run_benchmark(m, n, k, num_runs=100) -> dict[str, dict[str, float]]:
"""Run FP8 vs KVFP4 quantization benchmark and return metrics."""
tensor_bf16 = torch.randn(m, n, k, dtype=torch.bfloat16, device="cuda")
# --- FP8 ---
for _ in range(3): # warmup
_ = tensor_bf16 * 2
torch.cuda.synchronize()
start = time.time()
for _ in range(num_runs):
tensor_fp8 = tensor_bf16.to(torch.float8_e4m3fn)
torch.cuda.synchronize()
fp8_quant_time = (time.time() - start) / num_runs
start = time.time()
for _ in range(num_runs):
tensor_fp8_dequant = tensor_fp8.to(torch.bfloat16)
torch.cuda.synchronize()
fp8_dequant_time = (time.time() - start) / num_runs
fp8_metrics = calculate_accuracy_metrics(tensor_bf16, tensor_fp8_dequant)
# --- KVFP4 ---
# Warmup (also triggers torch.compile for the quantize/dequantize kernels)
tensor_fp4, scale_factors = KVFP4QuantizeUtil.batched_quantize(tensor_bf16)
_ = KVFP4QuantizeUtil.batched_dequantize(tensor_fp4, scale_factors)
torch.cuda.synchronize()
start = time.time()
for _ in range(num_runs):
tensor_fp4, scale_factors = KVFP4QuantizeUtil.batched_quantize(tensor_bf16)
torch.cuda.synchronize()
fp4_quant_time = (time.time() - start) / num_runs
start = time.time()
for _ in range(num_runs):
tensor_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
tensor_fp4, scale_factors
)
torch.cuda.synchronize()
fp4_dequant_time = (time.time() - start) / num_runs
fp4_metrics = calculate_accuracy_metrics(tensor_bf16, tensor_fp4_dequant)
return {
"fp8": {
"quant_time": fp8_quant_time,
"dequant_time": fp8_dequant_time,
**fp8_metrics,
},
"fp4": {
"quant_time": fp4_quant_time,
"dequant_time": fp4_dequant_time,
**fp4_metrics,
},
}
# default tensor shapes (m, n, k)
# [M, 1, 576]: DeepSeekR1-FP4 MLA
# [M, 8, 64]: gpt-oss-20b MHA
MNK_FACTORS = [
(64, 1, 576),
(512, 1, 576),
(1024, 1, 576),
(4096, 1, 576),
(2868672, 1, 576),
(64, 8, 64),
(512, 8, 64),
(1024, 8, 64),
(4096, 8, 64),
(2868672, 8, 64),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
def test_kvfp4_quant_dequant(m, n, k):
"""Benchmark FP8 vs KVFP4 for predefined tensor shapes."""
print(f"\n=== Running benchmark for tensor shape: [{m}, {n}, {k}] ===")
results = run_benchmark(m, n, k)
print("FP8:", results["fp8"])
print("FP4:", results["fp4"])
# Basic assertions to make sure metrics are reasonable
assert results["fp4"]["MSE"] < 1.0
assert results["fp8"]["MSE"] < 1.0