Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b2b7bd0
Unverified
Commit
6b2b7bd0
authored
Apr 17, 2026
by
sychen52
Committed by
GitHub
Apr 17, 2026
Browse files
Add nvfp4 support to reshape_and_cache_flash (#37332)
Signed-off-by:
Shiyang Chen
<
shiychen@nvidia.com
>
parent
70770268
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
679 additions
and
51 deletions
+679
-51
CMakeLists.txt
CMakeLists.txt
+14
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+22
-2
csrc/nvfp4_kv_cache_kernels.cu
csrc/nvfp4_kv_cache_kernels.cu
+275
-0
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+99
-21
tests/kernels/quantization/nvfp4_utils.py
tests/kernels/quantization/nvfp4_utils.py
+54
-0
vllm/config/cache.py
vllm/config/cache.py
+1
-0
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+4
-2
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+122
-12
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+61
-11
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+27
-3
No files found.
CMakeLists.txt
View file @
6b2b7bd0
...
...
@@ -923,6 +923,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
list
(
APPEND VLLM_STABLE_EXT_SRC
"
${
SRCS
}
"
)
# nvfp4_kv_cache_kernels uses non-stable torch API and is called directly
# from cache_kernels.cu, so it belongs in _C rather than _C_stable.
set
(
NVFP4_KV_SRC
"csrc/nvfp4_kv_cache_kernels.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
NVFP4_KV_SRC
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
target_sources
(
_C PRIVATE
${
NVFP4_KV_SRC
}
)
target_compile_definitions
(
_C PRIVATE ENABLE_NVFP4_SM120=1
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_NVFP4_SM120=1"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_CUTLASS_MOE_SM120=1"
)
message
(
STATUS
"Building NVFP4 for archs:
${
FP4_ARCHS
}
"
)
...
...
@@ -949,6 +957,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
list
(
APPEND VLLM_STABLE_EXT_SRC
"
${
SRCS
}
"
)
set
(
NVFP4_KV_SRC
"csrc/nvfp4_kv_cache_kernels.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
NVFP4_KV_SRC
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
target_sources
(
_C PRIVATE
${
NVFP4_KV_SRC
}
)
target_compile_definitions
(
_C PRIVATE ENABLE_NVFP4_SM100=1
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_NVFP4_SM100=1"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_CUTLASS_MOE_SM100=1"
)
message
(
STATUS
"Building NVFP4 for archs:
${
FP4_ARCHS
}
"
)
...
...
csrc/cache_kernels.cu
View file @
6b2b7bd0
...
...
@@ -724,6 +724,28 @@ void reshape_and_cache_flash(
int
num_tokens
=
slot_mapping
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
kv_cache_dtype
==
"nvfp4"
)
{
#if defined(ENABLE_NVFP4_SM100) || defined(ENABLE_NVFP4_SM120)
// NVFP4 dispatch is compiled separately for SM100+.
extern
void
reshape_and_cache_nvfp4_dispatch
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
);
reshape_and_cache_nvfp4_dispatch
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
k_scale
,
v_scale
);
return
;
#else
TORCH_CHECK
(
false
,
"NVFP4 KV cache requires SM100+ (Blackwell). "
"Please rebuild vllm with a Blackwell-compatible CUDA target."
);
#endif
}
// Original FP8/auto path.
int
block_size
=
key_cache
.
size
(
1
);
int64_t
key_stride
=
key
.
stride
(
0
);
...
...
@@ -741,8 +763,6 @@ void reshape_and_cache_flash(
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
DISPATCH_BY_KV_CACHE_DTYPE
(
key
.
dtype
(),
kv_cache_dtype
,
CALL_RESHAPE_AND_CACHE_FLASH
);
...
...
csrc/nvfp4_kv_cache_kernels.cu
0 → 100644
View file @
6b2b7bd0
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// NVFP4 KV cache store kernel.
// Quantizes bf16 key/value to packed FP4 + FP8 block scales and writes them
// into the paged KV cache.
//
// Per page layout: [K_data | K_scale | V_data | V_scale]
// Both data and scale regions are contiguous per head, enabling direct
// TMA descriptor use.
//
// Reuses device functions from nvfp4_utils.cuh:
// - cvt_warp_fp16_to_fp4() for bf16 → fp4 quantization + block scale
// - pack_fp4() for packing float pairs to fp4
// - reciprocal_approximate_ftz() for fast reciprocal
#define NVFP4_ENABLE_ELTS16 1
#include "libtorch_stable/quantization/fp4/nvfp4_utils.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "dispatch_utils.h"
namespace
vllm
{
// Compute swizzled scale offset for SM100 trtllm-gen MHA kernel.
// The swizzle pattern for HND layout is:
// [T//4, 4, 4, S//4] → permute(0, 2, 3, 1) → reshape to [T, S]
// where T = block_size (page_size), S = scale_dim = head_size // 16.
//
// For a linear (t, s) position, the swizzled position is:
// swizzled_t = (t / 4) * 4 + (s / (S / 4))
// swizzled_s = (s % (S / 4)) * 4 + (t % 4)
__device__
__forceinline__
int
swizzle_scale_offset
(
int
t
,
int
s
,
int
scale_dim
)
{
int
s_group
=
scale_dim
/
4
;
int
swizzled_t
=
(
t
/
4
)
*
4
+
(
s
/
s_group
);
int
swizzled_s
=
(
s
%
s_group
)
*
4
+
(
t
%
4
);
return
swizzled_t
*
scale_dim
+
swizzled_s
;
}
// Kernel: quantize bf16 key/value to NVFP4 and store in paged KV cache.
//
// Takes separate data and scale cache pointers for K and V.
// Within each KV side, data and scale are separate contiguous regions.
//
// Threading: one CUDA block per token, threads process heads and
// groups of 16 elements within each head.
template
<
typename
scalar_t
>
__global__
void
reshape_and_cache_nvfp4_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
uint8_t
*
__restrict__
key_data_cache
,
// data region for K
uint8_t
*
__restrict__
value_data_cache
,
// data region for V
uint8_t
*
__restrict__
key_scale_cache
,
// scale region for K
uint8_t
*
__restrict__
value_scale_cache
,
// scale region for V
const
int64_t
*
__restrict__
slot_mapping
,
// [num_actual_tokens]
const
float
*
__restrict__
k_scale_ptr
,
// pointer to checkpoint k_scale
const
float
*
__restrict__
v_scale_ptr
,
// pointer to checkpoint v_scale
const
int64_t
key_stride
,
// key.stride(0) in elements
const
int64_t
value_stride
,
// value.stride(0) in elements
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
int64_t
data_block_stride
,
// data cache stride for dim 0
const
int64_t
data_head_stride
,
// data cache stride for heads
const
int64_t
data_block_offset_stride
,
// data cache stride for tokens
const
int64_t
scale_block_stride
,
// scale cache stride for dim 0
const
int64_t
scale_head_stride
,
// scale cache stride for heads
const
int64_t
scale_block_offset_stride
// scale cache stride for tokens
)
{
using
CudaType
=
typename
CUDATypeConverter
<
scalar_t
>::
Type
;
using
PVec
=
PackedVec
<
CudaType
,
CVT_FP4_PACK16
>
;
static
constexpr
int
ELTS
=
CVT_FP4_ELTS_PER_THREAD
;
// 16 or 8
static
constexpr
int
THREADS_PER_SF
=
CVT_FP4_SF_VEC_SIZE
/
ELTS
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
if
(
slot_idx
<
0
)
return
;
const
int64_t
block_idx
=
slot_idx
/
block_size
;
const
int
block_offset
=
static_cast
<
int
>
(
slot_idx
%
block_size
);
const
int
scale_dim
=
head_size
/
16
;
const
int
groups_per_head
=
head_size
/
CVT_FP4_SF_VEC_SIZE
;
const
int
total_groups
=
num_heads
*
groups_per_head
;
const
int
tid
=
threadIdx
.
x
;
const
int
num_thread_groups
=
blockDim
.
x
/
THREADS_PER_SF
;
const
int
tg_id
=
tid
/
THREADS_PER_SF
;
const
int
tg_lane
=
tid
%
THREADS_PER_SF
;
// Process both K (kv=0) and V (kv=1)
#pragma unroll
for
(
int
kv
=
0
;
kv
<
2
;
kv
++
)
{
const
scalar_t
*
__restrict__
src
=
(
kv
==
0
)
?
key
:
value
;
const
float
global_scale
=
1.0
f
/
((
kv
==
0
)
?
*
k_scale_ptr
:
*
v_scale_ptr
);
const
int64_t
src_stride
=
(
kv
==
0
)
?
key_stride
:
value_stride
;
uint8_t
*
__restrict__
data_cache
=
(
kv
==
0
)
?
key_data_cache
:
value_data_cache
;
uint8_t
*
__restrict__
sc_cache
=
(
kv
==
0
)
?
key_scale_cache
:
value_scale_cache
;
// Source pointer for this token (use actual stride, not assumed contiguous)
const
CudaType
*
__restrict__
token_src
=
reinterpret_cast
<
const
CudaType
*>
(
src
)
+
token_idx
*
src_stride
;
// Destination bases in data and scale caches for this token's block
uint8_t
*
__restrict__
data_block
=
data_cache
+
block_idx
*
data_block_stride
;
uint8_t
*
__restrict__
scale_block
=
sc_cache
+
block_idx
*
scale_block_stride
;
for
(
int
g
=
tg_id
;
g
<
total_groups
;
g
+=
num_thread_groups
)
{
const
int
head
=
g
/
groups_per_head
;
const
int
group_in_head
=
g
%
groups_per_head
;
// Load 16 (or 8) bf16 elements from source
PVec
in_vec
;
const
CudaType
*
__restrict__
src_ptr
=
token_src
+
head
*
head_size
+
group_in_head
*
CVT_FP4_SF_VEC_SIZE
+
tg_lane
*
ELTS
;
#pragma unroll
for
(
int
i
=
0
;
i
<
ELTS
/
2
;
i
++
)
{
in_vec
.
elts
[
i
]
=
reinterpret_cast
<
const
typename
PackedTypeConverter
<
CudaType
>::
Type
*>
(
src_ptr
)[
i
];
}
// Quantize: produces packed fp4 and writes scale factor.
uint8_t
sf_val
;
uint8_t
*
sf_out_ptr
=
(
tg_lane
==
0
)
?
&
sf_val
:
nullptr
;
fp4_packed_t
packed
=
cvt_warp_fp16_to_fp4
<
CudaType
,
THREADS_PER_SF
>
(
in_vec
,
global_scale
,
sf_out_ptr
);
// Write packed FP4 data to data cache
uint8_t
*
__restrict__
data_dst
=
data_block
+
head
*
data_head_stride
+
block_offset
*
data_block_offset_stride
;
#if CVT_FP4_PACK16
{
// 16 elements → 8 bytes (u32x2)
int
data_byte_offset
=
group_in_head
*
8
;
reinterpret_cast
<
uint64_t
*>
(
data_dst
+
data_byte_offset
)[
0
]
=
(
uint64_t
(
packed
.
hi
)
<<
32
)
|
uint64_t
(
packed
.
lo
);
}
#else
{
// 8 elements → 4 bytes (uint32_t)
int
data_byte_offset
=
group_in_head
*
CVT_FP4_SF_VEC_SIZE
/
2
+
tg_lane
*
ELTS
/
2
;
reinterpret_cast
<
uint32_t
*>
(
data_dst
+
data_byte_offset
)[
0
]
=
packed
;
}
#endif
// Write block scale to scale cache.
// K (kv==0): linear layout (no swizzle).
// V (kv==1): swizzled layout for SM100 trtllm-gen MHA kernel.
if
(
sf_out_ptr
!=
nullptr
)
{
int
scale_idx
=
group_in_head
;
uint8_t
*
__restrict__
scale_dst
;
if
(
kv
==
0
)
{
scale_dst
=
scale_block
+
head
*
scale_head_stride
+
block_offset
*
scale_block_offset_stride
+
scale_idx
;
}
else
{
int
swizzled_offset
=
swizzle_scale_offset
(
block_offset
,
scale_idx
,
scale_dim
);
int
swizzled_t
=
swizzled_offset
/
scale_dim
;
int
swizzled_s
=
swizzled_offset
%
scale_dim
;
scale_dst
=
scale_block
+
head
*
scale_head_stride
+
swizzled_t
*
scale_block_offset_stride
+
swizzled_s
;
}
*
scale_dst
=
sf_val
;
}
}
}
}
}
// namespace vllm
// Non-template entry point callable from cache_kernels.cu.
// Receives key_cache/value_cache as kv_cache[:, 0] and kv_cache[:, 1].
// Each KV side contains both data and scale:
// page = [K_data | K_scale | V_data | V_scale]
void
reshape_and_cache_nvfp4_dispatch
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
)
{
int
num_tokens
=
slot_mapping
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
int
data_dim
=
head_size
/
2
;
int
scale_dim
=
head_size
/
16
;
int
full_dim
=
data_dim
+
scale_dim
;
// key_cache is kv_cache[:, 0] with shape
// [num_blocks, block_size, num_heads, full_dim] in logical order.
// Strides encode the physical layout (HND or NHD).
TORCH_CHECK
(
key_cache
.
dim
()
==
4
,
"key_cache must be 4D"
);
TORCH_CHECK
(
key_cache
.
size
(
3
)
==
full_dim
,
"key_cache last dim must be data_dim + scale_dim, got "
,
key_cache
.
size
(
3
),
" expected "
,
full_dim
);
int
block_size
=
key_cache
.
size
(
1
);
TORCH_CHECK
(
head_size
%
16
==
0
,
"head_size must be divisible by 16 for NVFP4 KV cache"
);
TORCH_CHECK
(
block_size
%
4
==
0
,
"block_size must be divisible by 4 for NVFP4 KV cache swizzle"
);
// Detect physical layout from strides (based on full_dim).
// HND: head stride > block_offset stride.
bool
is_hnd
=
key_cache
.
stride
(
2
)
>
key_cache
.
stride
(
1
);
int64_t
data_block_stride
=
key_cache
.
stride
(
0
);
// page_bytes
int64_t
data_head_stride
,
data_block_offset_stride
;
if
(
is_hnd
)
{
data_head_stride
=
(
int64_t
)
block_size
*
data_dim
;
data_block_offset_stride
=
data_dim
;
}
else
{
data_head_stride
=
data_dim
;
data_block_offset_stride
=
(
int64_t
)
num_heads
*
data_dim
;
}
// Page layout: [K_data | K_scale | V_data | V_scale]
// Scale follows data within each KV side.
int64_t
data_per_kv
=
(
int64_t
)
num_heads
*
block_size
*
data_dim
;
uint8_t
*
key_scale_ptr
=
key_cache
.
data_ptr
<
uint8_t
>
()
+
data_per_kv
;
uint8_t
*
value_scale_ptr
=
value_cache
.
data_ptr
<
uint8_t
>
()
+
data_per_kv
;
// Scale strides: same page stride, inner strides from layout.
int64_t
scale_block_stride
=
data_block_stride
;
int64_t
scale_head_stride
,
scale_block_offset_stride
;
if
(
is_hnd
)
{
scale_head_stride
=
(
int64_t
)
block_size
*
scale_dim
;
scale_block_offset_stride
=
scale_dim
;
}
else
{
scale_head_stride
=
scale_dim
;
scale_block_offset_stride
=
(
int64_t
)
num_heads
*
scale_dim
;
}
const
float
*
k_scale_ptr
=
k_scale
.
data_ptr
<
float
>
();
const
float
*
v_scale_ptr
=
v_scale
.
data_ptr
<
float
>
();
int
groups_per_head
=
head_size
/
CVT_FP4_SF_VEC_SIZE
;
int
total_groups
=
num_heads
*
groups_per_head
;
constexpr
int
THREADS_PER_SF
=
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
;
int
num_threads
=
std
::
min
(
total_groups
*
THREADS_PER_SF
,
512
);
num_threads
=
((
num_threads
+
31
)
/
32
)
*
32
;
dim3
grid
(
num_tokens
);
dim3
block
(
num_threads
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
key
.
scalar_type
(),
"reshape_and_cache_nvfp4"
,
[
&
]
{
vllm
::
reshape_and_cache_nvfp4_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
uint8_t
>
(),
value_cache
.
data_ptr
<
uint8_t
>
(),
key_scale_ptr
,
value_scale_ptr
,
slot_mapping
.
data_ptr
<
int64_t
>
(),
k_scale_ptr
,
v_scale_ptr
,
key
.
stride
(
0
),
value
.
stride
(
0
),
num_heads
,
head_size
,
block_size
,
data_block_stride
,
data_head_stride
,
data_block_offset_stride
,
scale_block_stride
,
scale_head_stride
,
scale_block_offset_stride
);
});
}
tests/kernels/attention/test_cache.py
View file @
6b2b7bd0
...
...
@@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
scaled_dequantize
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
nvfp4_kv_cache_split_views
,
set_random_seed
COPYING_DIRECTION
=
[(
"cuda"
,
"cpu"
),
(
"cuda"
,
"cuda"
),
(
"cpu"
,
"cuda"
)]
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -172,7 +172,7 @@ def test_reshape_and_cache(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
+
[
"nvfp4"
]
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_layout"
,
CACHE_LAYOUTS
)
@
pytest
.
mark
.
parametrize
(
"kv_scale_type"
,
KV_SCALE_TYPES
)
@
pytest
.
mark
.
parametrize
(
"implementation"
,
RESHAPE_FLASH_IMPLEMENTATIONS
)
...
...
@@ -202,6 +202,25 @@ def test_reshape_and_cache_flash(
if
kv_scale_type
==
"attn_head"
and
implementation
!=
"cuda"
:
pytest
.
skip
(
"Only CUDA implementation supports attn_head scaling."
)
if
kv_cache_dtype
==
"nvfp4"
:
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
"NVFP4 requires compute capability >= 10.0 (Blackwell)."
)
if
implementation
!=
"cuda"
:
pytest
.
skip
(
"NVFP4 only supports CUDA implementation."
)
if
kv_scale_type
!=
"tensor"
:
pytest
.
skip
(
"NVFP4 only supports per-tensor scaling."
)
if
head_size
%
16
!=
0
:
pytest
.
skip
(
"NVFP4 requires head_size divisible by 16."
)
if
(
head_size
//
16
)
%
4
!=
0
:
pytest
.
skip
(
"NVFP4 requires (head_size // 16) divisible by 4 "
"for 4x4 block scale swizzle."
)
if
block_size
%
4
!=
0
:
pytest
.
skip
(
"NVFP4 requires block_size divisible by 4."
)
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
pytest
.
skip
(
"NVFP4 quantization only supports fp16/bf16 input."
)
# fp8 conversion requires continugous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens
=
num_tokens
//
2
...
...
@@ -229,7 +248,23 @@ def test_reshape_and_cache_flash(
del
key_caches
del
value_caches
if
kv_scale_type
==
"tensor"
:
# For nvfp4, the factory returns kv[:, 0] and kv[:, 1] like all dtypes.
# Split views are still needed for dequant verification.
key_scale_cache
=
None
value_scale_cache
=
None
nvfp4_key_data
=
None
nvfp4_value_data
=
None
if
kv_cache_dtype
==
"nvfp4"
:
(
nvfp4_key_data
,),
(
key_scale_cache
,)
=
nvfp4_kv_cache_split_views
(
key_cache
)
(
nvfp4_value_data
,),
(
value_scale_cache
,)
=
nvfp4_kv_cache_split_views
(
value_cache
)
if
kv_cache_dtype
==
"nvfp4"
:
# Global scale = amax / 448 (per-tensor)
k_scale
=
(
key
.
abs
().
amax
()
/
448.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
abs
().
amax
()
/
448.0
).
to
(
torch
.
float32
)
elif
kv_scale_type
==
"tensor"
:
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
else
:
# "attn_head"
...
...
@@ -240,8 +275,9 @@ def test_reshape_and_cache_flash(
y
=
x
if
kv_cache_layout
==
"NHD"
else
x
.
permute
(
0
,
2
,
1
,
3
)
return
y
.
contiguous
()
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
if
kv_cache_dtype
!=
"nvfp4"
:
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
def
convert_fp8_local
(
output
,
input
,
scale
,
kv_dtype
):
fp8_input
=
input
.
view
(
current_platform
.
fp8_dtype
())
...
...
@@ -257,7 +293,7 @@ def test_reshape_and_cache_flash(
result
=
fp8_input
.
to
(
output
.
dtype
)
*
scale
.
view
(
1
,
-
1
,
1
,
1
)
output
.
copy_
(
result
)
# Clone the KV caches.
# Clone the KV caches
(for non-nvfp4, used as reference baseline)
.
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache_compact
,
dtype
=
torch
.
float16
)
convert_fp8_local
(
cloned_key_cache
,
key_cache_compact
,
k_scale
,
kv_cache_dtype
)
...
...
@@ -265,25 +301,27 @@ def test_reshape_and_cache_flash(
convert_fp8_local
(
cloned_value_cache
,
value_cache_compact
,
v_scale
,
kv_cache_dtype
)
el
se
:
el
if
kv_cache_dtype
!=
"nvfp4"
:
cloned_key_cache
=
key_cache_compact
.
clone
()
cloned_value_cache
=
value_cache_compact
.
clone
()
# Call the reshape_and_cache kernel.
if
implementation
==
"cuda"
:
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
,
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]),
)
if
kv_cache_dtype
!=
"nvfp4"
:
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
,
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]),
)
ops
.
reshape_and_cache_flash
(
key
,
value
,
...
...
@@ -309,6 +347,46 @@ def test_reshape_and_cache_flash(
k_scale
,
v_scale
,
)
if
kv_cache_dtype
==
"nvfp4"
:
# Verify NVFP4 by dequantizing the entire cache and comparing
# the written positions against original bf16 values.
# Same pattern as FP8: dequant whole cache, then extract and compare.
from
tests.kernels.quantization.nvfp4_utils
import
(
dequant_nvfp4_kv_cache
,
)
def
dequant_nvfp4_cache_nhd
(
data_cache
,
scale_cache
,
global_scale
):
# data_cache: [N, T, H, data_dim] NHD (contiguous inner dims)
# scale_cache: [N, T, H, scale_dim] NHD (contiguous inner dims)
# Permute to HND layout for the dequant utility.
data_hnd
=
data_cache
.
permute
(
0
,
2
,
1
,
3
)
scale_hnd
=
scale_cache
.
permute
(
0
,
2
,
1
,
3
)
result_hnd
=
dequant_nvfp4_kv_cache
(
data_hnd
,
scale_hnd
,
global_scale
,
head_size
,
block_size
)
return
result_hnd
.
permute
(
0
,
2
,
1
,
3
)
# back to [N, T, H, D]
result_key_cache
=
dequant_nvfp4_cache_nhd
(
nvfp4_key_data
,
key_scale_cache
,
k_scale
.
item
()
)
result_value_cache
=
dequant_nvfp4_cache_nhd
(
nvfp4_value_data
,
value_scale_cache
,
v_scale
.
item
()
)
# Flatten [num_blocks, block_size] → [num_slots] and index by slot_mapping.
num_slots
=
num_blocks
*
block_size
result_key_flat
=
result_key_cache
.
reshape
(
num_slots
,
num_heads
,
head_size
)
result_value_flat
=
result_value_cache
.
reshape
(
num_slots
,
num_heads
,
head_size
)
torch
.
testing
.
assert_close
(
result_key_flat
[
slot_mapping
],
key
.
float
(),
atol
=
1.5
,
rtol
=
0.5
)
torch
.
testing
.
assert_close
(
result_value_flat
[
slot_mapping
],
value
.
float
(),
atol
=
1.5
,
rtol
=
0.5
)
return
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
...
...
tests/kernels/quantization/nvfp4_utils.py
View file @
6b2b7bd0
...
...
@@ -88,6 +88,60 @@ def break_fp4_bytes(a, dtype):
return
values
.
reshape
(
m
,
n
*
2
).
to
(
dtype
=
dtype
)
def
dequant_nvfp4_kv_cache
(
fp4_data
:
torch
.
Tensor
,
block_scale
:
torch
.
Tensor
,
global_scale
:
float
,
head_size
:
int
,
block_size
:
int
,
)
->
torch
.
Tensor
:
"""Dequantize an NVFP4 KV cache with 4x4-swizzled block scales.
The input must be in HND layout so that the last two dims are
(block_size, last_dim). For NHD caches, permute to HND first.
Args:
fp4_data: [..., num_heads, block_size, head_size//2] uint8 packed fp4.
block_scale: [..., num_heads, block_size, head_size//16] fp8 block
scales (as uint8 or float8_e4m3fn).
global_scale: checkpoint dequant scale (k_scale or v_scale).
head_size: head dimension.
block_size: page size.
Returns:
[..., num_heads, block_size, head_size] float32.
"""
data_dim
=
head_size
//
2
scale_dim
=
head_size
//
16
fp4_packed
=
fp4_data
sf_swizzled
=
block_scale
.
view
(
torch
.
uint8
)
# Unswizzle 4x4 block scales on (block_size, scale_dim) plane.
# [..., T, S] → [..., T//4, 4, sg, 4] → permute → [..., T, S]
batch_shape
=
sf_swizzled
.
shape
[:
-
2
]
T
,
S
=
block_size
,
scale_dim
sg
=
S
//
4
sf_reshape
=
sf_swizzled
.
reshape
(
*
batch_shape
,
T
//
4
,
4
,
sg
,
4
)
ndim
=
sf_reshape
.
ndim
# Swap the last four dims: (..., T//4, 4, sg, 4) → (..., T//4, 4, 4, sg)
perm
=
list
(
range
(
ndim
-
4
))
+
[
ndim
-
4
,
ndim
-
1
,
ndim
-
3
,
ndim
-
2
]
sf_linear
=
sf_reshape
.
permute
(
*
perm
).
reshape
(
*
batch_shape
,
T
,
S
)
sf_f32
=
sf_linear
.
view
(
torch
.
float8_e4m3fn
).
to
(
torch
.
float32
)
# Unpack fp4
shape
=
fp4_packed
.
shape
# [..., T, data_dim]
fp4_flat
=
fp4_packed
.
reshape
(
-
1
,
data_dim
)
fp4_vals
=
break_fp4_bytes
(
fp4_flat
,
torch
.
float32
)
fp4_vals
=
fp4_vals
.
reshape
(
*
shape
[:
-
1
],
head_size
)
# Dequant: fp4_val * block_scale * global_scale per 16-element group
return
(
fp4_vals
.
reshape
(
*
shape
[:
-
1
],
scale_dim
,
16
)
*
(
sf_f32
*
global_scale
).
unsqueeze
(
-
1
)
).
reshape
(
*
shape
[:
-
1
],
head_size
)
def
get_nvfp4_global_scale
(
a
:
torch
.
Tensor
):
return
(
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
abs
(
a
).
max
().
to
(
torch
.
float32
)
...
...
vllm/config/cache.py
View file @
6b2b7bd0
...
...
@@ -30,6 +30,7 @@ CacheDType = Literal[
"turboquant_3bit_nc"
,
"int8_per_token_head"
,
"fp8_per_token_head"
,
"nvfp4"
,
]
MambaDType
=
Literal
[
"auto"
,
"float32"
,
"float16"
]
MambaCacheMode
=
Literal
[
"all"
,
"align"
,
"none"
]
...
...
vllm/model_executor/layers/attention/attention.py
View file @
6b2b7bd0
...
...
@@ -387,7 +387,9 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
query_quant
=
None
if
(
self
.
impl
.
supports_quant_query_input
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
(
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
or
self
.
kv_cache_dtype
==
"nvfp4"
)
and
not
self
.
kv_cache_dtype
.
endswith
(
"per_token_head"
)
):
is_per_head
=
(
...
...
@@ -492,7 +494,7 @@ class Attention(nn.Module, AttentionLayerBase):
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
}
assert
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
,
"nvfp4"
}
# check if query quantization is supported
if
self
.
impl
.
supports_quant_query_input
:
...
...
vllm/utils/torch_utils.py
View file @
6b2b7bd0
...
...
@@ -46,6 +46,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"turboquant_4bit_nc"
:
torch
.
uint8
,
"turboquant_k3v4_nc"
:
torch
.
uint8
,
"turboquant_3bit_nc"
:
torch
.
uint8
,
"nvfp4"
:
torch
.
uint8
,
}
TORCH_DTYPE_TO_NUMPY_DTYPE
=
{
...
...
@@ -59,17 +60,19 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
=
{
# TODO: Add more modelopt kv cache dtype
# mappings here when it supported by some attention backend
# (for example supports nvfp4).
"fp8"
:
"fp8_e4m3"
,
"nvfp4"
:
"nvfp4"
,
}
T
=
TypeVar
(
"T"
)
def
is_quantized_kv_cache
(
kv_cache_dtype
:
str
)
->
bool
:
return
kv_cache_dtype
.
startswith
(
"fp8"
)
or
kv_cache_dtype
.
endswith
(
"per_token_head"
)
return
(
kv_cache_dtype
.
startswith
(
"fp8"
)
or
kv_cache_dtype
.
endswith
(
"per_token_head"
)
or
kv_cache_dtype
==
"nvfp4"
)
def
kv_cache_uses_per_token_head_scales
(
kv_cache_dtype
:
str
)
->
bool
:
...
...
@@ -299,6 +302,8 @@ def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
and
kv_algo
.
get
(
"type"
)
==
"float"
):
kv_algo
=
"fp8"
elif
kv_algo
.
get
(
"num_bits"
)
==
4
and
kv_algo
.
get
(
"type"
)
==
"float"
:
kv_algo
=
"nvfp4"
else
:
# Unknown/unsupported format - return "auto" as safe fallback
logger
.
warning
(
...
...
@@ -375,6 +380,95 @@ def set_random_seed(seed: int | None) -> None:
current_platform
.
manual_seed_all
(
seed
)
def
nvfp4_kv_cache_full_dim
(
head_size
:
int
)
->
int
:
"""Packed last dim for NVFP4 KV cache: fp4 data + fp8 block scales."""
return
head_size
//
2
+
head_size
//
16
def
_nvfp4_split_data_scale
(
kv_side
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Split a single NVFP4 KV-side buffer into data and scale views.
The input is a 4D tensor for one KV side (K or V) whose last
dimension is ``full_dim = data_dim + scale_dim``. The physical
layout within each side is [data | scale], both packed contiguously.
Args:
kv_side: 4D uint8 tensor with shape
``(num_pages, dim_1, dim_2, full_dim)``.
May be in any permutation order (NHD or HND).
Returns:
``(data, scale)`` where
``data`` is a uint8 view with shape
``(num_pages, dim_1, dim_2, data_dim)``.
``scale`` is a float8_e4m3fn view with shape
``(num_pages, dim_1, dim_2, scale_dim)``.
"""
num_pages
=
kv_side
.
shape
[
0
]
dim_1
,
dim_2
=
kv_side
.
shape
[
1
],
kv_side
.
shape
[
2
]
full_dim
=
kv_side
.
shape
[
3
]
data_dim
=
full_dim
*
8
//
9
scale_dim
=
full_dim
-
data_dim
data_per_kv
=
dim_1
*
dim_2
*
data_dim
page_bytes
=
kv_side
.
stride
(
0
)
# Derive inner strides from the kv_side strides, scaling by the
# ratio of the target dim to full_dim. This preserves the physical
# layout (NHD vs HND) encoded in the input tensor's strides.
s1
=
kv_side
.
stride
(
1
)
*
data_dim
//
full_dim
s2
=
kv_side
.
stride
(
2
)
*
data_dim
//
full_dim
data_shape
=
(
num_pages
,
dim_1
,
dim_2
,
data_dim
)
data_strides
=
(
page_bytes
,
s1
,
s2
,
1
)
s1_s
=
kv_side
.
stride
(
1
)
*
scale_dim
//
full_dim
s2_s
=
kv_side
.
stride
(
2
)
*
scale_dim
//
full_dim
scale_shape
=
(
num_pages
,
dim_1
,
dim_2
,
scale_dim
)
scale_strides
=
(
page_bytes
,
s1_s
,
s2_s
,
1
)
base
=
kv_side
.
storage_offset
()
data
=
torch
.
as_strided
(
kv_side
,
data_shape
,
data_strides
,
storage_offset
=
base
)
scale
=
torch
.
as_strided
(
kv_side
,
scale_shape
,
scale_strides
,
storage_offset
=
base
+
data_per_kv
).
view
(
torch
.
float8_e4m3fn
)
return
data
,
scale
def
nvfp4_kv_cache_split_views
(
kv_cache
:
torch
.
Tensor
)
->
tuple
[
tuple
,
tuple
]:
"""Split an NVFP4 KV cache tensor into data and scale views.
Accepts either a 5D tensor ``(num_pages, 2, dim_2, dim_3, full_dim)``
or a 4D single-side tensor ``(num_pages, dim_2, dim_3, full_dim)``.
Per-page layout: [K_data | K_scale | V_data | V_scale].
Each KV side is self-contained (data followed by its scale), so the
5D case simply splits each side independently.
The returned views are in the same dim order as the input (NHD or
HND), so callers get views matching whichever order they passed in.
Args:
kv_cache: 5D or 4D uint8 tensor where the last dimension is
``full_dim = data_dim + scale_dim = 9 * head_size / 16``.
Returns:
For 5D input:
``(k_data, v_data), (k_scale, v_scale)``
For 4D input (single KV side):
``(data,), (scale,)``
"""
if
kv_cache
.
dim
()
==
4
:
data
,
scale
=
_nvfp4_split_data_scale
(
kv_cache
)
return
(
data
,),
(
scale
,)
k_data
,
k_scale
=
_nvfp4_split_data_scale
(
kv_cache
[:,
0
])
v_data
,
v_scale
=
_nvfp4_split_data_scale
(
kv_cache
[:,
1
])
return
(
k_data
,
v_data
),
(
k_scale
,
v_scale
)
def
create_kv_caches_with_random_flash
(
num_blocks
:
int
,
block_size
:
int
,
...
...
@@ -401,15 +495,31 @@ def create_kv_caches_with_random_flash(
value_caches
:
list
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
size
=
kv_cache_allocation_shape
,
dtype
=
dtype
,
device
=
device
).
permute
(
*
stride_order
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
"fp8"
:
_generate_random_fp8
(
key_value_cache
,
-
scale
,
scale
)
if
cache_dtype
==
"nvfp4"
:
# Full page dim: fp4 data + fp8 block scales per head.
# Per page layout: [K_data | K_scale | V_data | V_scale]
# Returns [:, 0] and [:, 1] like all other dtypes.
full_dim
=
nvfp4_kv_cache_full_dim
(
head_size
)
nvfp4_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
full_dim
)
nvfp4_phys
=
tuple
(
nvfp4_shape
[
i
]
for
i
in
stride_order
)
inv
=
[
stride_order
.
index
(
i
)
for
i
in
range
(
len
(
stride_order
))]
key_value_cache
=
torch
.
randint
(
0
,
256
,
nvfp4_phys
,
dtype
=
dtype
,
device
=
device
,
).
permute
(
*
inv
)
else
:
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
key_value_cache
=
torch
.
empty
(
size
=
kv_cache_allocation_shape
,
dtype
=
dtype
,
device
=
device
).
permute
(
*
stride_order
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
"fp8"
:
_generate_random_fp8
(
key_value_cache
,
-
scale
,
scale
)
else
:
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
key_caches
.
append
(
key_value_cache
[:,
0
])
value_caches
.
append
(
key_value_cache
[:,
1
])
return
key_caches
,
value_caches
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
6b2b7bd0
...
...
@@ -42,7 +42,12 @@ from vllm.utils.flashinfer import (
)
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
,
is_strictly_contiguous
from
vllm.utils.torch_utils
import
(
is_quantized_kv_cache
,
is_strictly_contiguous
,
nvfp4_kv_cache_full_dim
,
nvfp4_kv_cache_split_views
,
)
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
...
...
@@ -355,6 +360,10 @@ class FlashInferBackend(AttentionBackend):
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
if
cache_dtype_str
==
"nvfp4"
:
# Packed layout: fp4 data + fp8 block scales in last dim
last_dim
=
nvfp4_kv_cache_full_dim
(
head_size
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
last_dim
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
...
...
@@ -608,11 +617,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
cache_dtype
=
self
.
cache_config
.
cache_dtype
# Cannot use self.kv_cache_spec.dtype here because kv_cache_spec
# storage dtype may not be the same as the op dtype (uint8 vs fp8_e4m3)
self
.
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
cache_dtype
)
self
.
is_kvcache_nvfp4
=
self
.
cache_dtype
==
"nvfp4"
if
self
.
is_kvcache_nvfp4
:
# For NVFP4, kv_cache_dtype stays as the string "nvfp4"
# which is passed to FlashInferImpl
self
.
kv_cache_dtype
=
self
.
cache_dtype
raise
NotImplementedError
(
"nvfp4 KV cache is not yet supported"
)
else
:
self
.
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
cache_dtype
)
else
:
self
.
cache_dtype
=
"auto"
self
.
is_kvcache_nvfp4
=
False
assert
self
.
kv_cache_spec
.
dtype
==
self
.
model_config
.
dtype
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
...
...
@@ -626,7 +643,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
can_use_trtllm
and
not
vllm_config
.
attention_config
.
disable_flashinfer_q_quantization
):
self
.
q_data_type
=
self
.
kv_cache_dtype
if
self
.
is_kvcache_nvfp4
:
# NVFP4 KV cache uses FP8 quantized queries
self
.
q_data_type
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
"fp8_e4m3"
)
else
:
self
.
q_data_type
=
self
.
kv_cache_dtype
else
:
self
.
q_data_type
=
self
.
model_config
.
dtype
...
...
@@ -1228,6 +1251,8 @@ class FlashInferImpl(AttentionImpl):
self
.
sliding_window
[
0
]
if
self
.
sliding_window
is
not
None
else
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
is_kvcache_nvfp4
=
kv_cache_dtype
==
"nvfp4"
self
.
fp4_data_dim
=
head_size
//
2
if
self
.
is_kvcache_nvfp4
else
0
self
.
logits_soft_cap
=
logits_soft_cap
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
...
...
@@ -1406,7 +1431,16 @@ class FlashInferImpl(AttentionImpl):
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
kv_cache_permute
=
kv_cache
.
permute
(
*
stride_order
)
kv_cache_permute
=
kv_cache
.
permute
(
*
stride_order
)
# HND and contiguous
# For NVFP4, the kv_cache last dim is full_dim (data + scale packed).
# Split into correctly-strided data and scale views.
nvfp4_kv_data
=
None
nvfp4_kv_block_scales
=
None
if
self
.
is_kvcache_nvfp4
:
nvfp4_kv_data
,
nvfp4_kv_block_scales
=
nvfp4_kv_cache_split_views
(
kv_cache_permute
)
use_dcp
=
self
.
dcp_world_size
>
1
...
...
@@ -1490,8 +1524,20 @@ class FlashInferImpl(AttentionImpl):
assert
self
.
o_sf_scale
is
None
out
=
output
[
num_decode_tokens
:]
if
attn_metadata
.
q_data_type
!=
FP8_DTYPE
and
is_quantized_kv_cache
(
self
.
kv_cache_dtype
prefill_kv_block_scales
=
None
if
self
.
is_kvcache_nvfp4
:
# NVFP4 trtllm-gen kernel requires FP8 query.
assert
attn_metadata
.
q_data_type
==
FP8_DTYPE
,
(
"NVFP4 KV cache requires FP8 quantized queries for "
"trtllm-gen prefill. Set "
"disable_flashinfer_q_quantization=False."
)
mock_kv_cache
=
nvfp4_kv_data
mock_block_table
=
block_tables_prefill
prefill_kv_block_scales
=
nvfp4_kv_block_scales
# noqa: F841
elif
(
attn_metadata
.
q_data_type
!=
FP8_DTYPE
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
):
# TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention
...
...
@@ -1636,7 +1682,9 @@ class FlashInferImpl(AttentionImpl):
trtllm_batch_decode_with_kv_cache
(
query
=
decode_query
,
kv_cache
=
kv_cache_permute
,
kv_cache
=
nvfp4_kv_data
if
self
.
is_kvcache_nvfp4
else
kv_cache_permute
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables_decode
,
seq_lens
=
seq_lens_decode
,
...
...
@@ -1667,11 +1715,13 @@ class FlashInferImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
k_cache
=
kv_cache
[:,
0
]
v_cache
=
kv_cache
[:,
1
]
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
k
v
_cache
[:,
0
]
,
k
v_cache
[:,
1
]
,
k_cache
,
v_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
...
...
vllm/v1/kv_cache_interface.py
View file @
6b2b7bd0
...
...
@@ -17,7 +17,7 @@ from vllm.logger import init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
get_dtype_size
from
vllm.utils.torch_utils
import
get_dtype_size
,
nvfp4_kv_cache_full_dim
logger
=
init_logger
(
__name__
)
...
...
@@ -38,11 +38,20 @@ class KVQuantMode(IntEnum):
FP8_PER_TENSOR
=
1
# per-tensor scales (current fp8 path)
INT8_PER_TOKEN_HEAD
=
2
# per-token-head dynamic scales for int8
FP8_PER_TOKEN_HEAD
=
3
# per-token-head dynamic scales for fp8
NVFP4
=
4
# packed fp4 data + fp8 block scales
@
property
def
is_per_token_head
(
self
)
->
bool
:
"""True for any per-token-head quantization mode."""
return
self
>=
2
return
self
in
(
KVQuantMode
.
INT8_PER_TOKEN_HEAD
,
KVQuantMode
.
FP8_PER_TOKEN_HEAD
,
)
@
property
def
is_nvfp4
(
self
)
->
bool
:
"""True for NVFP4 packed quantization mode."""
return
self
==
KVQuantMode
.
NVFP4
def
get_kv_quant_mode
(
kv_cache_dtype
:
str
)
->
KVQuantMode
:
...
...
@@ -51,7 +60,9 @@ def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
return
KVQuantMode
.
INT8_PER_TOKEN_HEAD
if
kv_cache_dtype
==
"fp8_per_token_head"
:
return
KVQuantMode
.
FP8_PER_TOKEN_HEAD
if
kv_cache_dtype
.
startswith
(
"fp8"
):
if
kv_cache_dtype
==
"nvfp4"
:
return
KVQuantMode
.
NVFP4
if
isinstance
(
kv_cache_dtype
,
str
)
and
kv_cache_dtype
.
startswith
(
"fp8"
):
return
KVQuantMode
.
FP8_PER_TENSOR
return
KVQuantMode
.
NONE
...
...
@@ -237,6 +248,19 @@ class FullAttentionSpec(AttentionSpec):
@
property
def
real_page_size_bytes
(
self
)
->
int
:
if
self
.
kv_quant_mode
.
is_nvfp4
:
# Packed layout per head: fp4 data + fp8 block scales.
# fp4 data: head_size//2 bytes (2 fp4 values per byte)
# fp8 block scale: head_size//16 bytes (1 scale per 16 elements)
last_dim
=
nvfp4_kv_cache_full_dim
(
self
.
head_size
)
+
nvfp4_kv_cache_full_dim
(
self
.
head_size_v
)
return
(
self
.
block_size
*
self
.
num_kv_heads
*
last_dim
*
get_dtype_size
(
self
.
dtype
)
)
return
(
self
.
block_size
*
self
.
num_kv_heads
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment