"docs/vscode:/vscode.git/clone" did not exist on "709c9f1f257fd15545ad19b89ed5019cb5ea338b"
Unverified Commit 6b2b7bd0 authored by sychen52's avatar sychen52 Committed by GitHub
Browse files

Add nvfp4 support to reshape_and_cache_flash (#37332)


Signed-off-by: default avatarShiyang Chen <shiychen@nvidia.com>
parent 70770268
...@@ -923,6 +923,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -923,6 +923,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}") CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
# nvfp4_kv_cache_kernels uses non-stable torch API and is called directly
# from cache_kernels.cu, so it belongs in _C rather than _C_stable.
set(NVFP4_KV_SRC "csrc/nvfp4_kv_cache_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${NVFP4_KV_SRC}"
CUDA_ARCHS "${FP4_ARCHS}")
target_sources(_C PRIVATE ${NVFP4_KV_SRC})
target_compile_definitions(_C PRIVATE ENABLE_NVFP4_SM120=1)
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
...@@ -949,6 +957,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -949,6 +957,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}") CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
set(NVFP4_KV_SRC "csrc/nvfp4_kv_cache_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${NVFP4_KV_SRC}"
CUDA_ARCHS "${FP4_ARCHS}")
target_sources(_C PRIVATE ${NVFP4_KV_SRC})
target_compile_definitions(_C PRIVATE ENABLE_NVFP4_SM100=1)
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
......
...@@ -724,6 +724,28 @@ void reshape_and_cache_flash( ...@@ -724,6 +724,28 @@ void reshape_and_cache_flash(
int num_tokens = slot_mapping.size(0); int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "nvfp4") {
#if defined(ENABLE_NVFP4_SM100) || defined(ENABLE_NVFP4_SM120)
// NVFP4 dispatch is compiled separately for SM100+.
extern void reshape_and_cache_nvfp4_dispatch(
torch::Tensor & key, torch::Tensor & value, torch::Tensor & key_cache,
torch::Tensor & value_cache, torch::Tensor & slot_mapping,
torch::Tensor & k_scale, torch::Tensor & v_scale);
reshape_and_cache_nvfp4_dispatch(key, value, key_cache, value_cache,
slot_mapping, k_scale, v_scale);
return;
#else
TORCH_CHECK(false,
"NVFP4 KV cache requires SM100+ (Blackwell). "
"Please rebuild vllm with a Blackwell-compatible CUDA target.");
#endif
}
// Original FP8/auto path.
int block_size = key_cache.size(1); int block_size = key_cache.size(1);
int64_t key_stride = key.stride(0); int64_t key_stride = key.stride(0);
...@@ -741,8 +763,6 @@ void reshape_and_cache_flash( ...@@ -741,8 +763,6 @@ void reshape_and_cache_flash(
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_FLASH); CALL_RESHAPE_AND_CACHE_FLASH);
......
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// NVFP4 KV cache store kernel.
// Quantizes bf16 key/value to packed FP4 + FP8 block scales and writes them
// into the paged KV cache.
//
// Per page layout: [K_data | K_scale | V_data | V_scale]
// Both data and scale regions are contiguous per head, enabling direct
// TMA descriptor use.
//
// Reuses device functions from nvfp4_utils.cuh:
// - cvt_warp_fp16_to_fp4() for bf16 → fp4 quantization + block scale
// - pack_fp4() for packing float pairs to fp4
// - reciprocal_approximate_ftz() for fast reciprocal
#define NVFP4_ENABLE_ELTS16 1
#include "libtorch_stable/quantization/fp4/nvfp4_utils.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "dispatch_utils.h"
namespace vllm {
// Compute swizzled scale offset for SM100 trtllm-gen MHA kernel.
// The swizzle pattern for HND layout is:
// [T//4, 4, 4, S//4] → permute(0, 2, 3, 1) → reshape to [T, S]
// where T = block_size (page_size), S = scale_dim = head_size // 16.
//
// For a linear (t, s) position, the swizzled position is:
// swizzled_t = (t / 4) * 4 + (s / (S / 4))
// swizzled_s = (s % (S / 4)) * 4 + (t % 4)
__device__ __forceinline__ int swizzle_scale_offset(int t, int s,
int scale_dim) {
int s_group = scale_dim / 4;
int swizzled_t = (t / 4) * 4 + (s / s_group);
int swizzled_s = (s % s_group) * 4 + (t % 4);
return swizzled_t * scale_dim + swizzled_s;
}
// Kernel: quantize bf16 key/value to NVFP4 and store in paged KV cache.
//
// Takes separate data and scale cache pointers for K and V.
// Within each KV side, data and scale are separate contiguous regions.
//
// Threading: one CUDA block per token, threads process heads and
// groups of 16 elements within each head.
template <typename scalar_t>
__global__ void reshape_and_cache_nvfp4_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
uint8_t* __restrict__ key_data_cache, // data region for K
uint8_t* __restrict__ value_data_cache, // data region for V
uint8_t* __restrict__ key_scale_cache, // scale region for K
uint8_t* __restrict__ value_scale_cache, // scale region for V
const int64_t* __restrict__ slot_mapping, // [num_actual_tokens]
const float* __restrict__ k_scale_ptr, // pointer to checkpoint k_scale
const float* __restrict__ v_scale_ptr, // pointer to checkpoint v_scale
const int64_t key_stride, // key.stride(0) in elements
const int64_t value_stride, // value.stride(0) in elements
const int num_heads, const int head_size, const int block_size,
const int64_t data_block_stride, // data cache stride for dim 0
const int64_t data_head_stride, // data cache stride for heads
const int64_t data_block_offset_stride, // data cache stride for tokens
const int64_t scale_block_stride, // scale cache stride for dim 0
const int64_t scale_head_stride, // scale cache stride for heads
const int64_t scale_block_offset_stride // scale cache stride for tokens
) {
using CudaType = typename CUDATypeConverter<scalar_t>::Type;
using PVec = PackedVec<CudaType, CVT_FP4_PACK16>;
static constexpr int ELTS = CVT_FP4_ELTS_PER_THREAD; // 16 or 8
static constexpr int THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / ELTS;
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) return;
const int64_t block_idx = slot_idx / block_size;
const int block_offset = static_cast<int>(slot_idx % block_size);
const int scale_dim = head_size / 16;
const int groups_per_head = head_size / CVT_FP4_SF_VEC_SIZE;
const int total_groups = num_heads * groups_per_head;
const int tid = threadIdx.x;
const int num_thread_groups = blockDim.x / THREADS_PER_SF;
const int tg_id = tid / THREADS_PER_SF;
const int tg_lane = tid % THREADS_PER_SF;
// Process both K (kv=0) and V (kv=1)
#pragma unroll
for (int kv = 0; kv < 2; kv++) {
const scalar_t* __restrict__ src = (kv == 0) ? key : value;
const float global_scale = 1.0f / ((kv == 0) ? *k_scale_ptr : *v_scale_ptr);
const int64_t src_stride = (kv == 0) ? key_stride : value_stride;
uint8_t* __restrict__ data_cache =
(kv == 0) ? key_data_cache : value_data_cache;
uint8_t* __restrict__ sc_cache =
(kv == 0) ? key_scale_cache : value_scale_cache;
// Source pointer for this token (use actual stride, not assumed contiguous)
const CudaType* __restrict__ token_src =
reinterpret_cast<const CudaType*>(src) + token_idx * src_stride;
// Destination bases in data and scale caches for this token's block
uint8_t* __restrict__ data_block =
data_cache + block_idx * data_block_stride;
uint8_t* __restrict__ scale_block =
sc_cache + block_idx * scale_block_stride;
for (int g = tg_id; g < total_groups; g += num_thread_groups) {
const int head = g / groups_per_head;
const int group_in_head = g % groups_per_head;
// Load 16 (or 8) bf16 elements from source
PVec in_vec;
const CudaType* __restrict__ src_ptr =
token_src + head * head_size + group_in_head * CVT_FP4_SF_VEC_SIZE +
tg_lane * ELTS;
#pragma unroll
for (int i = 0; i < ELTS / 2; i++) {
in_vec.elts[i] = reinterpret_cast<
const typename PackedTypeConverter<CudaType>::Type*>(src_ptr)[i];
}
// Quantize: produces packed fp4 and writes scale factor.
uint8_t sf_val;
uint8_t* sf_out_ptr = (tg_lane == 0) ? &sf_val : nullptr;
fp4_packed_t packed = cvt_warp_fp16_to_fp4<CudaType, THREADS_PER_SF>(
in_vec, global_scale, sf_out_ptr);
// Write packed FP4 data to data cache
uint8_t* __restrict__ data_dst = data_block + head * data_head_stride +
block_offset * data_block_offset_stride;
#if CVT_FP4_PACK16
{
// 16 elements → 8 bytes (u32x2)
int data_byte_offset = group_in_head * 8;
reinterpret_cast<uint64_t*>(data_dst + data_byte_offset)[0] =
(uint64_t(packed.hi) << 32) | uint64_t(packed.lo);
}
#else
{
// 8 elements → 4 bytes (uint32_t)
int data_byte_offset =
group_in_head * CVT_FP4_SF_VEC_SIZE / 2 + tg_lane * ELTS / 2;
reinterpret_cast<uint32_t*>(data_dst + data_byte_offset)[0] = packed;
}
#endif
// Write block scale to scale cache.
// K (kv==0): linear layout (no swizzle).
// V (kv==1): swizzled layout for SM100 trtllm-gen MHA kernel.
if (sf_out_ptr != nullptr) {
int scale_idx = group_in_head;
uint8_t* __restrict__ scale_dst;
if (kv == 0) {
scale_dst = scale_block + head * scale_head_stride +
block_offset * scale_block_offset_stride + scale_idx;
} else {
int swizzled_offset =
swizzle_scale_offset(block_offset, scale_idx, scale_dim);
int swizzled_t = swizzled_offset / scale_dim;
int swizzled_s = swizzled_offset % scale_dim;
scale_dst = scale_block + head * scale_head_stride +
swizzled_t * scale_block_offset_stride + swizzled_s;
}
*scale_dst = sf_val;
}
}
}
}
} // namespace vllm
// Non-template entry point callable from cache_kernels.cu.
// Receives key_cache/value_cache as kv_cache[:, 0] and kv_cache[:, 1].
// Each KV side contains both data and scale:
// page = [K_data | K_scale | V_data | V_scale]
void reshape_and_cache_nvfp4_dispatch(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int data_dim = head_size / 2;
int scale_dim = head_size / 16;
int full_dim = data_dim + scale_dim;
// key_cache is kv_cache[:, 0] with shape
// [num_blocks, block_size, num_heads, full_dim] in logical order.
// Strides encode the physical layout (HND or NHD).
TORCH_CHECK(key_cache.dim() == 4, "key_cache must be 4D");
TORCH_CHECK(key_cache.size(3) == full_dim,
"key_cache last dim must be data_dim + scale_dim, got ",
key_cache.size(3), " expected ", full_dim);
int block_size = key_cache.size(1);
TORCH_CHECK(head_size % 16 == 0,
"head_size must be divisible by 16 for NVFP4 KV cache");
TORCH_CHECK(block_size % 4 == 0,
"block_size must be divisible by 4 for NVFP4 KV cache swizzle");
// Detect physical layout from strides (based on full_dim).
// HND: head stride > block_offset stride.
bool is_hnd = key_cache.stride(2) > key_cache.stride(1);
int64_t data_block_stride = key_cache.stride(0); // page_bytes
int64_t data_head_stride, data_block_offset_stride;
if (is_hnd) {
data_head_stride = (int64_t)block_size * data_dim;
data_block_offset_stride = data_dim;
} else {
data_head_stride = data_dim;
data_block_offset_stride = (int64_t)num_heads * data_dim;
}
// Page layout: [K_data | K_scale | V_data | V_scale]
// Scale follows data within each KV side.
int64_t data_per_kv = (int64_t)num_heads * block_size * data_dim;
uint8_t* key_scale_ptr = key_cache.data_ptr<uint8_t>() + data_per_kv;
uint8_t* value_scale_ptr = value_cache.data_ptr<uint8_t>() + data_per_kv;
// Scale strides: same page stride, inner strides from layout.
int64_t scale_block_stride = data_block_stride;
int64_t scale_head_stride, scale_block_offset_stride;
if (is_hnd) {
scale_head_stride = (int64_t)block_size * scale_dim;
scale_block_offset_stride = scale_dim;
} else {
scale_head_stride = scale_dim;
scale_block_offset_stride = (int64_t)num_heads * scale_dim;
}
const float* k_scale_ptr = k_scale.data_ptr<float>();
const float* v_scale_ptr = v_scale.data_ptr<float>();
int groups_per_head = head_size / CVT_FP4_SF_VEC_SIZE;
int total_groups = num_heads * groups_per_head;
constexpr int THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
int num_threads = std::min(total_groups * THREADS_PER_SF, 512);
num_threads = ((num_threads + 31) / 32) * 32;
dim3 grid(num_tokens);
dim3 block(num_threads);
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_REDUCED_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_nvfp4", [&] {
vllm::reshape_and_cache_nvfp4_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<uint8_t>(), value_cache.data_ptr<uint8_t>(),
key_scale_ptr, value_scale_ptr,
slot_mapping.data_ptr<int64_t>(), k_scale_ptr, v_scale_ptr,
key.stride(0), value.stride(0), num_heads, head_size,
block_size, data_block_stride, data_head_stride,
data_block_offset_stride, scale_block_stride, scale_head_stride,
scale_block_offset_stride);
});
}
...@@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck ...@@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import nvfp4_kv_cache_split_views, set_random_seed
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")] COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
...@@ -172,7 +172,7 @@ def test_reshape_and_cache( ...@@ -172,7 +172,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE + ["nvfp4"])
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES) @pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS) @pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
...@@ -202,6 +202,25 @@ def test_reshape_and_cache_flash( ...@@ -202,6 +202,25 @@ def test_reshape_and_cache_flash(
if kv_scale_type == "attn_head" and implementation != "cuda": if kv_scale_type == "attn_head" and implementation != "cuda":
pytest.skip("Only CUDA implementation supports attn_head scaling.") pytest.skip("Only CUDA implementation supports attn_head scaling.")
if kv_cache_dtype == "nvfp4":
if not current_platform.has_device_capability(100):
pytest.skip("NVFP4 requires compute capability >= 10.0 (Blackwell).")
if implementation != "cuda":
pytest.skip("NVFP4 only supports CUDA implementation.")
if kv_scale_type != "tensor":
pytest.skip("NVFP4 only supports per-tensor scaling.")
if head_size % 16 != 0:
pytest.skip("NVFP4 requires head_size divisible by 16.")
if (head_size // 16) % 4 != 0:
pytest.skip(
"NVFP4 requires (head_size // 16) divisible by 4 "
"for 4x4 block scale swizzle."
)
if block_size % 4 != 0:
pytest.skip("NVFP4 requires block_size divisible by 4.")
if dtype not in (torch.float16, torch.bfloat16):
pytest.skip("NVFP4 quantization only supports fp16/bf16 input.")
# fp8 conversion requires continugous memory buffer. Reduce the number of # fp8 conversion requires continugous memory buffer. Reduce the number of
# blocks and tokens to consume less memory. # blocks and tokens to consume less memory.
num_tokens = num_tokens // 2 num_tokens = num_tokens // 2
...@@ -229,7 +248,23 @@ def test_reshape_and_cache_flash( ...@@ -229,7 +248,23 @@ def test_reshape_and_cache_flash(
del key_caches del key_caches
del value_caches del value_caches
if kv_scale_type == "tensor": # For nvfp4, the factory returns kv[:, 0] and kv[:, 1] like all dtypes.
# Split views are still needed for dequant verification.
key_scale_cache = None
value_scale_cache = None
nvfp4_key_data = None
nvfp4_value_data = None
if kv_cache_dtype == "nvfp4":
(nvfp4_key_data,), (key_scale_cache,) = nvfp4_kv_cache_split_views(key_cache)
(nvfp4_value_data,), (value_scale_cache,) = nvfp4_kv_cache_split_views(
value_cache
)
if kv_cache_dtype == "nvfp4":
# Global scale = amax / 448 (per-tensor)
k_scale = (key.abs().amax() / 448.0).to(torch.float32)
v_scale = (value.abs().amax() / 448.0).to(torch.float32)
elif kv_scale_type == "tensor":
k_scale = (key.amax() / 64.0).to(torch.float32) k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32)
else: # "attn_head" else: # "attn_head"
...@@ -240,6 +275,7 @@ def test_reshape_and_cache_flash( ...@@ -240,6 +275,7 @@ def test_reshape_and_cache_flash(
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3) y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
return y.contiguous() return y.contiguous()
if kv_cache_dtype != "nvfp4":
key_cache_compact = permute_and_compact(key_cache) key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache) value_cache_compact = permute_and_compact(value_cache)
...@@ -257,7 +293,7 @@ def test_reshape_and_cache_flash( ...@@ -257,7 +293,7 @@ def test_reshape_and_cache_flash(
result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1) result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1)
output.copy_(result) output.copy_(result)
# Clone the KV caches. # Clone the KV caches (for non-nvfp4, used as reference baseline).
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale, kv_cache_dtype) convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
...@@ -265,11 +301,13 @@ def test_reshape_and_cache_flash( ...@@ -265,11 +301,13 @@ def test_reshape_and_cache_flash(
convert_fp8_local( convert_fp8_local(
cloned_value_cache, value_cache_compact, v_scale, kv_cache_dtype cloned_value_cache, value_cache_compact, v_scale, kv_cache_dtype
) )
else: elif kv_cache_dtype != "nvfp4":
cloned_key_cache = key_cache_compact.clone() cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone() cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
if implementation == "cuda": if implementation == "cuda":
if kv_cache_dtype != "nvfp4":
opcheck( opcheck(
torch.ops._C_cache_ops.reshape_and_cache_flash, torch.ops._C_cache_ops.reshape_and_cache_flash,
( (
...@@ -309,6 +347,46 @@ def test_reshape_and_cache_flash( ...@@ -309,6 +347,46 @@ def test_reshape_and_cache_flash(
k_scale, k_scale,
v_scale, v_scale,
) )
if kv_cache_dtype == "nvfp4":
# Verify NVFP4 by dequantizing the entire cache and comparing
# the written positions against original bf16 values.
# Same pattern as FP8: dequant whole cache, then extract and compare.
from tests.kernels.quantization.nvfp4_utils import (
dequant_nvfp4_kv_cache,
)
def dequant_nvfp4_cache_nhd(data_cache, scale_cache, global_scale):
# data_cache: [N, T, H, data_dim] NHD (contiguous inner dims)
# scale_cache: [N, T, H, scale_dim] NHD (contiguous inner dims)
# Permute to HND layout for the dequant utility.
data_hnd = data_cache.permute(0, 2, 1, 3)
scale_hnd = scale_cache.permute(0, 2, 1, 3)
result_hnd = dequant_nvfp4_kv_cache(
data_hnd, scale_hnd, global_scale, head_size, block_size
)
return result_hnd.permute(0, 2, 1, 3) # back to [N, T, H, D]
result_key_cache = dequant_nvfp4_cache_nhd(
nvfp4_key_data, key_scale_cache, k_scale.item()
)
result_value_cache = dequant_nvfp4_cache_nhd(
nvfp4_value_data, value_scale_cache, v_scale.item()
)
# Flatten [num_blocks, block_size] → [num_slots] and index by slot_mapping.
num_slots = num_blocks * block_size
result_key_flat = result_key_cache.reshape(num_slots, num_heads, head_size)
result_value_flat = result_value_cache.reshape(num_slots, num_heads, head_size)
torch.testing.assert_close(
result_key_flat[slot_mapping], key.float(), atol=1.5, rtol=0.5
)
torch.testing.assert_close(
result_value_flat[slot_mapping], value.float(), atol=1.5, rtol=0.5
)
return
key_cache_compact = permute_and_compact(key_cache) key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache) value_cache_compact = permute_and_compact(value_cache)
......
...@@ -88,6 +88,60 @@ def break_fp4_bytes(a, dtype): ...@@ -88,6 +88,60 @@ def break_fp4_bytes(a, dtype):
return values.reshape(m, n * 2).to(dtype=dtype) return values.reshape(m, n * 2).to(dtype=dtype)
def dequant_nvfp4_kv_cache(
fp4_data: torch.Tensor,
block_scale: torch.Tensor,
global_scale: float,
head_size: int,
block_size: int,
) -> torch.Tensor:
"""Dequantize an NVFP4 KV cache with 4x4-swizzled block scales.
The input must be in HND layout so that the last two dims are
(block_size, last_dim). For NHD caches, permute to HND first.
Args:
fp4_data: [..., num_heads, block_size, head_size//2] uint8 packed fp4.
block_scale: [..., num_heads, block_size, head_size//16] fp8 block
scales (as uint8 or float8_e4m3fn).
global_scale: checkpoint dequant scale (k_scale or v_scale).
head_size: head dimension.
block_size: page size.
Returns:
[..., num_heads, block_size, head_size] float32.
"""
data_dim = head_size // 2
scale_dim = head_size // 16
fp4_packed = fp4_data
sf_swizzled = block_scale.view(torch.uint8)
# Unswizzle 4x4 block scales on (block_size, scale_dim) plane.
# [..., T, S] → [..., T//4, 4, sg, 4] → permute → [..., T, S]
batch_shape = sf_swizzled.shape[:-2]
T, S = block_size, scale_dim
sg = S // 4
sf_reshape = sf_swizzled.reshape(*batch_shape, T // 4, 4, sg, 4)
ndim = sf_reshape.ndim
# Swap the last four dims: (..., T//4, 4, sg, 4) → (..., T//4, 4, 4, sg)
perm = list(range(ndim - 4)) + [ndim - 4, ndim - 1, ndim - 3, ndim - 2]
sf_linear = sf_reshape.permute(*perm).reshape(*batch_shape, T, S)
sf_f32 = sf_linear.view(torch.float8_e4m3fn).to(torch.float32)
# Unpack fp4
shape = fp4_packed.shape # [..., T, data_dim]
fp4_flat = fp4_packed.reshape(-1, data_dim)
fp4_vals = break_fp4_bytes(fp4_flat, torch.float32)
fp4_vals = fp4_vals.reshape(*shape[:-1], head_size)
# Dequant: fp4_val * block_scale * global_scale per 16-element group
return (
fp4_vals.reshape(*shape[:-1], scale_dim, 16)
* (sf_f32 * global_scale).unsqueeze(-1)
).reshape(*shape[:-1], head_size)
def get_nvfp4_global_scale(a: torch.Tensor): def get_nvfp4_global_scale(a: torch.Tensor):
return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32) return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
......
...@@ -30,6 +30,7 @@ CacheDType = Literal[ ...@@ -30,6 +30,7 @@ CacheDType = Literal[
"turboquant_3bit_nc", "turboquant_3bit_nc",
"int8_per_token_head", "int8_per_token_head",
"fp8_per_token_head", "fp8_per_token_head",
"nvfp4",
] ]
MambaDType = Literal["auto", "float32", "float16"] MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"] MambaCacheMode = Literal["all", "align", "none"]
......
...@@ -387,7 +387,9 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -387,7 +387,9 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None self.query_quant = None
if ( if (
self.impl.supports_quant_query_input self.impl.supports_quant_query_input
and self.kv_cache_dtype.startswith("fp8") and (
self.kv_cache_dtype.startswith("fp8") or self.kv_cache_dtype == "nvfp4"
)
and not self.kv_cache_dtype.endswith("per_token_head") and not self.kv_cache_dtype.endswith("per_token_head")
): ):
is_per_head = ( is_per_head = (
...@@ -492,7 +494,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -492,7 +494,7 @@ class Attention(nn.Module, AttentionLayerBase):
# which reduces overheads during decoding. # which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops # Otherwise queries are quantized using custom ops
# which causes decoding overheads # which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} assert self.kv_cache_dtype in {"fp8", "fp8_e4m3", "nvfp4"}
# check if query quantization is supported # check if query quantization is supported
if self.impl.supports_quant_query_input: if self.impl.supports_quant_query_input:
......
...@@ -46,6 +46,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -46,6 +46,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"turboquant_4bit_nc": torch.uint8, "turboquant_4bit_nc": torch.uint8,
"turboquant_k3v4_nc": torch.uint8, "turboquant_k3v4_nc": torch.uint8,
"turboquant_3bit_nc": torch.uint8, "turboquant_3bit_nc": torch.uint8,
"nvfp4": torch.uint8,
} }
TORCH_DTYPE_TO_NUMPY_DTYPE = { TORCH_DTYPE_TO_NUMPY_DTYPE = {
...@@ -59,17 +60,19 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = { ...@@ -59,17 +60,19 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = { MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
# TODO: Add more modelopt kv cache dtype
# mappings here when it supported by some attention backend
# (for example supports nvfp4).
"fp8": "fp8_e4m3", "fp8": "fp8_e4m3",
"nvfp4": "nvfp4",
} }
T = TypeVar("T") T = TypeVar("T")
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith("per_token_head") return (
kv_cache_dtype.startswith("fp8")
or kv_cache_dtype.endswith("per_token_head")
or kv_cache_dtype == "nvfp4"
)
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool: def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
...@@ -299,6 +302,8 @@ def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None: ...@@ -299,6 +302,8 @@ def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
and kv_algo.get("type") == "float" and kv_algo.get("type") == "float"
): ):
kv_algo = "fp8" kv_algo = "fp8"
elif kv_algo.get("num_bits") == 4 and kv_algo.get("type") == "float":
kv_algo = "nvfp4"
else: else:
# Unknown/unsupported format - return "auto" as safe fallback # Unknown/unsupported format - return "auto" as safe fallback
logger.warning( logger.warning(
...@@ -375,6 +380,95 @@ def set_random_seed(seed: int | None) -> None: ...@@ -375,6 +380,95 @@ def set_random_seed(seed: int | None) -> None:
current_platform.manual_seed_all(seed) current_platform.manual_seed_all(seed)
def nvfp4_kv_cache_full_dim(head_size: int) -> int:
"""Packed last dim for NVFP4 KV cache: fp4 data + fp8 block scales."""
return head_size // 2 + head_size // 16
def _nvfp4_split_data_scale(
kv_side: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Split a single NVFP4 KV-side buffer into data and scale views.
The input is a 4D tensor for one KV side (K or V) whose last
dimension is ``full_dim = data_dim + scale_dim``. The physical
layout within each side is [data | scale], both packed contiguously.
Args:
kv_side: 4D uint8 tensor with shape
``(num_pages, dim_1, dim_2, full_dim)``.
May be in any permutation order (NHD or HND).
Returns:
``(data, scale)`` where
``data`` is a uint8 view with shape
``(num_pages, dim_1, dim_2, data_dim)``.
``scale`` is a float8_e4m3fn view with shape
``(num_pages, dim_1, dim_2, scale_dim)``.
"""
num_pages = kv_side.shape[0]
dim_1, dim_2 = kv_side.shape[1], kv_side.shape[2]
full_dim = kv_side.shape[3]
data_dim = full_dim * 8 // 9
scale_dim = full_dim - data_dim
data_per_kv = dim_1 * dim_2 * data_dim
page_bytes = kv_side.stride(0)
# Derive inner strides from the kv_side strides, scaling by the
# ratio of the target dim to full_dim. This preserves the physical
# layout (NHD vs HND) encoded in the input tensor's strides.
s1 = kv_side.stride(1) * data_dim // full_dim
s2 = kv_side.stride(2) * data_dim // full_dim
data_shape = (num_pages, dim_1, dim_2, data_dim)
data_strides = (page_bytes, s1, s2, 1)
s1_s = kv_side.stride(1) * scale_dim // full_dim
s2_s = kv_side.stride(2) * scale_dim // full_dim
scale_shape = (num_pages, dim_1, dim_2, scale_dim)
scale_strides = (page_bytes, s1_s, s2_s, 1)
base = kv_side.storage_offset()
data = torch.as_strided(kv_side, data_shape, data_strides, storage_offset=base)
scale = torch.as_strided(
kv_side, scale_shape, scale_strides, storage_offset=base + data_per_kv
).view(torch.float8_e4m3fn)
return data, scale
def nvfp4_kv_cache_split_views(kv_cache: torch.Tensor) -> tuple[tuple, tuple]:
"""Split an NVFP4 KV cache tensor into data and scale views.
Accepts either a 5D tensor ``(num_pages, 2, dim_2, dim_3, full_dim)``
or a 4D single-side tensor ``(num_pages, dim_2, dim_3, full_dim)``.
Per-page layout: [K_data | K_scale | V_data | V_scale].
Each KV side is self-contained (data followed by its scale), so the
5D case simply splits each side independently.
The returned views are in the same dim order as the input (NHD or
HND), so callers get views matching whichever order they passed in.
Args:
kv_cache: 5D or 4D uint8 tensor where the last dimension is
``full_dim = data_dim + scale_dim = 9 * head_size / 16``.
Returns:
For 5D input:
``(k_data, v_data), (k_scale, v_scale)``
For 4D input (single KV side):
``(data,), (scale,)``
"""
if kv_cache.dim() == 4:
data, scale = _nvfp4_split_data_scale(kv_cache)
return (data,), (scale,)
k_data, k_scale = _nvfp4_split_data_scale(kv_cache[:, 0])
v_data, v_scale = _nvfp4_split_data_scale(kv_cache[:, 1])
return (k_data, v_data), (k_scale, v_scale)
def create_kv_caches_with_random_flash( def create_kv_caches_with_random_flash(
num_blocks: int, num_blocks: int,
block_size: int, block_size: int,
...@@ -401,6 +495,22 @@ def create_kv_caches_with_random_flash( ...@@ -401,6 +495,22 @@ def create_kv_caches_with_random_flash(
value_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
if cache_dtype == "nvfp4":
# Full page dim: fp4 data + fp8 block scales per head.
# Per page layout: [K_data | K_scale | V_data | V_scale]
# Returns [:, 0] and [:, 1] like all other dtypes.
full_dim = nvfp4_kv_cache_full_dim(head_size)
nvfp4_shape = (num_blocks, 2, block_size, num_heads, full_dim)
nvfp4_phys = tuple(nvfp4_shape[i] for i in stride_order)
inv = [stride_order.index(i) for i in range(len(stride_order))]
key_value_cache = torch.randint(
0,
256,
nvfp4_phys,
dtype=dtype,
device=device,
).permute(*inv)
else:
key_value_cache = torch.empty( key_value_cache = torch.empty(
size=kv_cache_allocation_shape, dtype=dtype, device=device size=kv_cache_allocation_shape, dtype=dtype, device=device
).permute(*stride_order) ).permute(*stride_order)
......
...@@ -42,7 +42,12 @@ from vllm.utils.flashinfer import ( ...@@ -42,7 +42,12 @@ from vllm.utils.flashinfer import (
) )
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_quantized_kv_cache, is_strictly_contiguous from vllm.utils.torch_utils import (
is_quantized_kv_cache,
is_strictly_contiguous,
nvfp4_kv_cache_full_dim,
nvfp4_kv_cache_split_views,
)
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -355,6 +360,10 @@ class FlashInferBackend(AttentionBackend): ...@@ -355,6 +360,10 @@ class FlashInferBackend(AttentionBackend):
head_size: int, head_size: int,
cache_dtype_str: str = "auto", cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if cache_dtype_str == "nvfp4":
# Packed layout: fp4 data + fp8 block scales in last dim
last_dim = nvfp4_kv_cache_full_dim(head_size)
return (num_blocks, 2, block_size, num_kv_heads, last_dim)
return (num_blocks, 2, block_size, num_kv_heads, head_size) return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod @staticmethod
...@@ -608,11 +617,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -608,11 +617,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.cache_dtype = self.cache_config.cache_dtype self.cache_dtype = self.cache_config.cache_dtype
# Cannot use self.kv_cache_spec.dtype here because kv_cache_spec # Cannot use self.kv_cache_spec.dtype here because kv_cache_spec
# storage dtype may not be the same as the op dtype (uint8 vs fp8_e4m3) # storage dtype may not be the same as the op dtype (uint8 vs fp8_e4m3)
self.is_kvcache_nvfp4 = self.cache_dtype == "nvfp4"
if self.is_kvcache_nvfp4:
# For NVFP4, kv_cache_dtype stays as the string "nvfp4"
# which is passed to FlashInferImpl
self.kv_cache_dtype = self.cache_dtype
raise NotImplementedError("nvfp4 KV cache is not yet supported")
else:
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype self.cache_dtype
) )
else: else:
self.cache_dtype = "auto" self.cache_dtype = "auto"
self.is_kvcache_nvfp4 = False
assert self.kv_cache_spec.dtype == self.model_config.dtype assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype
...@@ -626,6 +643,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -626,6 +643,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
can_use_trtllm can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization and not vllm_config.attention_config.disable_flashinfer_q_quantization
): ):
if self.is_kvcache_nvfp4:
# NVFP4 KV cache uses FP8 quantized queries
self.q_data_type = FlashInferBackend.get_fp8_dtype_for_flashinfer(
"fp8_e4m3"
)
else:
self.q_data_type = self.kv_cache_dtype self.q_data_type = self.kv_cache_dtype
else: else:
self.q_data_type = self.model_config.dtype self.q_data_type = self.model_config.dtype
...@@ -1228,6 +1251,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -1228,6 +1251,8 @@ class FlashInferImpl(AttentionImpl):
self.sliding_window[0] if self.sliding_window is not None else -1 self.sliding_window[0] if self.sliding_window is not None else -1
) )
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.is_kvcache_nvfp4 = kv_cache_dtype == "nvfp4"
self.fp4_data_dim = head_size // 2 if self.is_kvcache_nvfp4 else 0
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
...@@ -1406,7 +1431,16 @@ class FlashInferImpl(AttentionImpl): ...@@ -1406,7 +1431,16 @@ class FlashInferImpl(AttentionImpl):
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
stride_order = FlashInferBackend.get_kv_cache_stride_order() stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order) kv_cache_permute = kv_cache.permute(*stride_order) # HND and contiguous
# For NVFP4, the kv_cache last dim is full_dim (data + scale packed).
# Split into correctly-strided data and scale views.
nvfp4_kv_data = None
nvfp4_kv_block_scales = None
if self.is_kvcache_nvfp4:
nvfp4_kv_data, nvfp4_kv_block_scales = nvfp4_kv_cache_split_views(
kv_cache_permute
)
use_dcp = self.dcp_world_size > 1 use_dcp = self.dcp_world_size > 1
...@@ -1490,8 +1524,20 @@ class FlashInferImpl(AttentionImpl): ...@@ -1490,8 +1524,20 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None assert self.o_sf_scale is None
out = output[num_decode_tokens:] out = output[num_decode_tokens:]
if attn_metadata.q_data_type != FP8_DTYPE and is_quantized_kv_cache( prefill_kv_block_scales = None
self.kv_cache_dtype if self.is_kvcache_nvfp4:
# NVFP4 trtllm-gen kernel requires FP8 query.
assert attn_metadata.q_data_type == FP8_DTYPE, (
"NVFP4 KV cache requires FP8 quantized queries for "
"trtllm-gen prefill. Set "
"disable_flashinfer_q_quantization=False."
)
mock_kv_cache = nvfp4_kv_data
mock_block_table = block_tables_prefill
prefill_kv_block_scales = nvfp4_kv_block_scales # noqa: F841
elif (
attn_metadata.q_data_type != FP8_DTYPE
and self.kv_cache_dtype.startswith("fp8")
): ):
# TRTLLM prefill attention does not support BF16 Q # TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention # and fp8 kv cache. So to enable prefill attention
...@@ -1636,7 +1682,9 @@ class FlashInferImpl(AttentionImpl): ...@@ -1636,7 +1682,9 @@ class FlashInferImpl(AttentionImpl):
trtllm_batch_decode_with_kv_cache( trtllm_batch_decode_with_kv_cache(
query=decode_query, query=decode_query,
kv_cache=kv_cache_permute, kv_cache=nvfp4_kv_data
if self.is_kvcache_nvfp4
else kv_cache_permute,
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables_decode, block_tables=block_tables_decode,
seq_lens=seq_lens_decode, seq_lens=seq_lens_decode,
...@@ -1667,11 +1715,13 @@ class FlashInferImpl(AttentionImpl): ...@@ -1667,11 +1715,13 @@ class FlashInferImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash # and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
k_cache = kv_cache[:, 0]
v_cache = kv_cache[:, 1]
torch.ops._C_cache_ops.reshape_and_cache_flash( torch.ops._C_cache_ops.reshape_and_cache_flash(
key, key,
value, value,
kv_cache[:, 0], k_cache,
kv_cache[:, 1], v_cache,
slot_mapping, slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale,
......
...@@ -17,7 +17,7 @@ from vllm.logger import init_logger ...@@ -17,7 +17,7 @@ from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -38,11 +38,20 @@ class KVQuantMode(IntEnum): ...@@ -38,11 +38,20 @@ class KVQuantMode(IntEnum):
FP8_PER_TENSOR = 1 # per-tensor scales (current fp8 path) FP8_PER_TENSOR = 1 # per-tensor scales (current fp8 path)
INT8_PER_TOKEN_HEAD = 2 # per-token-head dynamic scales for int8 INT8_PER_TOKEN_HEAD = 2 # per-token-head dynamic scales for int8
FP8_PER_TOKEN_HEAD = 3 # per-token-head dynamic scales for fp8 FP8_PER_TOKEN_HEAD = 3 # per-token-head dynamic scales for fp8
NVFP4 = 4 # packed fp4 data + fp8 block scales
@property @property
def is_per_token_head(self) -> bool: def is_per_token_head(self) -> bool:
"""True for any per-token-head quantization mode.""" """True for any per-token-head quantization mode."""
return self >= 2 return self in (
KVQuantMode.INT8_PER_TOKEN_HEAD,
KVQuantMode.FP8_PER_TOKEN_HEAD,
)
@property
def is_nvfp4(self) -> bool:
"""True for NVFP4 packed quantization mode."""
return self == KVQuantMode.NVFP4
def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode: def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
...@@ -51,7 +60,9 @@ def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode: ...@@ -51,7 +60,9 @@ def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
return KVQuantMode.INT8_PER_TOKEN_HEAD return KVQuantMode.INT8_PER_TOKEN_HEAD
if kv_cache_dtype == "fp8_per_token_head": if kv_cache_dtype == "fp8_per_token_head":
return KVQuantMode.FP8_PER_TOKEN_HEAD return KVQuantMode.FP8_PER_TOKEN_HEAD
if kv_cache_dtype.startswith("fp8"): if kv_cache_dtype == "nvfp4":
return KVQuantMode.NVFP4
if isinstance(kv_cache_dtype, str) and kv_cache_dtype.startswith("fp8"):
return KVQuantMode.FP8_PER_TENSOR return KVQuantMode.FP8_PER_TENSOR
return KVQuantMode.NONE return KVQuantMode.NONE
...@@ -237,6 +248,19 @@ class FullAttentionSpec(AttentionSpec): ...@@ -237,6 +248,19 @@ class FullAttentionSpec(AttentionSpec):
@property @property
def real_page_size_bytes(self) -> int: def real_page_size_bytes(self) -> int:
if self.kv_quant_mode.is_nvfp4:
# Packed layout per head: fp4 data + fp8 block scales.
# fp4 data: head_size//2 bytes (2 fp4 values per byte)
# fp8 block scale: head_size//16 bytes (1 scale per 16 elements)
last_dim = nvfp4_kv_cache_full_dim(
self.head_size
) + nvfp4_kv_cache_full_dim(self.head_size_v)
return (
self.block_size
* self.num_kv_heads
* last_dim
* get_dtype_size(self.dtype)
)
return ( return (
self.block_size self.block_size
* self.num_kv_heads * self.num_kv_heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment