Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e96f56e
Unverified
Commit
9e96f56e
authored
Apr 26, 2025
by
Shu Wang
Committed by
GitHub
Apr 25, 2025
Browse files
Allocate kv_cache with stride order (#16605)
Signed-off-by:
shuw
<
shuw@nvidia.com
>
parent
b2789112
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
119 additions
and
50 deletions
+119
-50
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+21
-18
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+41
-19
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+4
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+25
-5
vllm/utils.py
vllm/utils.py
+10
-3
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+18
-5
No files found.
csrc/cache_kernels.cu
View file @
9e96f56e
...
...
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t
*
__restrict__
value_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
const
int64_t
block_stride
,
const
int64_t
page_stride
,
const
int64_t
head_stride
,
const
int64_t
key_stride
,
const
int64_t
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
// NOTE: slot_idx can be -1 if the token is padded
...
...
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_key_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_siz
e
+
head_idx
*
head_s
iz
e
+
head_offset
;
block_offset
*
page_strid
e
+
head_idx
*
head_s
trid
e
+
head_offset
;
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
...
...
@@ -396,16 +397,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride,
key
_stride, \
value_stride, num_heads, head_size,
block_size,
\
reinterpret_cast<const float*>(k_scale.data_ptr()),
\
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)
\
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>
\
<<<grid, block, 0, stream>>>(
\
reinterpret_cast<KV_T*>(key.data_ptr()),
\
reinterpret_cast<KV_T*>(value.data_ptr()),
\
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),
\
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),
\
slot_mapping.data_ptr<int64_t>(), block_stride,
page
_stride,
\
head_stride, key_stride,
value_stride, num_heads, head_size, \
block_size,
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void
reshape_and_cache_flash
(
...
...
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
int
head_size
=
key
.
size
(
2
);
int
block_size
=
key_cache
.
size
(
1
);
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int
block_stride
=
key_cache
.
stride
(
0
);
int64_t
key_stride
=
key
.
stride
(
0
);
int64_t
value_stride
=
value
.
stride
(
0
);
int64_t
block_stride
=
key_cache
.
stride
(
0
);
int64_t
page_stride
=
key_cache
.
stride
(
1
);
int64_t
head_stride
=
key_cache
.
stride
(
2
);
TORCH_CHECK
(
key_cache
.
stride
(
0
)
==
value_cache
.
stride
(
0
));
dim3
grid
(
num_tokens
);
...
...
tests/kernels/attention/test_cache.py
View file @
9e96f56e
...
...
@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
HEAD_SIZES
=
[
64
,
80
,
120
,
256
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
CACHE_LAYOUTS
=
[
"NHD"
,
"HND"
]
# Parameters for MLA tests.
KV_LORA_RANKS
=
[
512
]
...
...
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_layout"
,
CACHE_LAYOUTS
)
@
torch
.
inference_mode
()
def
test_reshape_and_cache_flash
(
kv_cache_factory_flashinfer
,
...
...
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_layout
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
# fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens
=
num_tokens
//
2
num_blocks
=
num_blocks
//
2
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
dtype
=
torch
.
long
,
device
=
device
)
qkv
=
torch
.
randn
(
num_tokens
,
3
,
num_heads
,
...
...
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
kv_cache_dtype
,
dtype
,
device
=
device
,
cache_layout
=
kv_cache_layout
,
)
key_cache
,
value_cache
=
key_caches
[
0
].
contiguous
(
),
value_caches
[
0
].
contiguous
()
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
del
key_caches
del
value_caches
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
def
permute_and_compact
(
x
):
y
=
x
if
kv_cache_layout
==
"NHD"
else
x
.
permute
(
0
,
2
,
1
,
3
)
return
y
.
contiguous
()
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
(),
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
(),
cloned_key_cache
=
torch
.
empty_like
(
key_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache_compact
,
k_scale
.
item
(),
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache_compact
,
v_scale
.
item
(),
kv_cache_dtype
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
cloned_key_cache
=
key_cache_compact
.
clone
()
cloned_value_cache
=
value_cache_compact
.
clone
()
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
...
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
key_cache
_compact
,
k_scale
.
item
(),
kv_dtype
=
kv_cache_dtype
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_value_cache
=
torch
.
empty_like
(
value_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
value_cache
_compact
,
v_scale
.
item
(),
kv_dtype
=
kv_cache_dtype
)
...
...
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
kv_cache_layout
==
"NHD"
:
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
else
:
cloned_key_cache
[
block_idx
,
:,
block_offset
,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
:,
block_offset
,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
...
...
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
atol
=
0.001
,
rtol
=
0.1
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_close
(
key_cache
_compact
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
_compact
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
vllm/attention/backends/abstract.py
View file @
9e96f56e
...
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
)
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
@
abstractmethod
def
swap_blocks
(
...
...
vllm/attention/backends/flashinfer.py
View file @
9e96f56e
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
import
os
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
...
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
return
stride_order
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
...
...
@@ -188,6 +200,7 @@ class FlashInferState(AttentionState):
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
self
.
runner
.
vllm_config
self
.
_kv_cache_layout
=
None
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
...
...
@@ -197,10 +210,15 @@ class FlashInferState(AttentionState):
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
def
get_kv_cache_layout
(
self
):
if
self
.
_kv_cache_layout
is
None
:
self
.
_kv_cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
return
self
.
_kv_cache_layout
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
self
.
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
...
...
@@ -213,7 +231,7 @@ class FlashInferState(AttentionState):
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
,
self
.
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
...
...
@@ -274,7 +292,8 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
self
.
get_kv_cache_layout
(),
use_tensor_cores
)
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
...
@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# when kv_cache is not provided.
...
...
@@ -1036,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
...
...
@@ -1051,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
...
...
vllm/utils.py
View file @
9e96f56e
...
...
@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
Optional
[
int
]
=
None
,
device
:
Optional
[
str
]
=
"cuda"
,
cache_layout
:
Optional
[
str
]
=
"NHD"
,
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
torch
.
Tensor
]]:
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
generic_kv_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
assert
cache_layout
in
(
"NHD"
,
"HND"
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
kv_cache_allocation_shape
=
tuple
(
generic_kv_cache_shape
[
i
]
for
i
in
stride_order
)
scale
=
head_size
**-
0.5
key_caches
:
list
[
torch
.
Tensor
]
=
[]
value_caches
:
list
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
size
=
k
ey_value_cache
_shape
,
key_value_cache
=
torch
.
empty
(
size
=
k
v_cache_allocation
_shape
,
dtype
=
torch_dtype
,
device
=
device
)
device
=
device
)
.
permute
(
*
stride_order
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
...
...
vllm/worker/cache_engine.py
View file @
9e96f56e
...
...
@@ -71,19 +71,32 @@ class CacheEngine:
device
:
str
,
)
->
List
[
torch
.
Tensor
]:
"""Allocates KV cache on the specified device."""
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_
generic_
shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
self
.
block_size
,
self
.
num_kv_heads
,
self
.
head_size
)
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
try
:
kv_cache_stride_order
=
self
.
attn_backend
.
get_kv_cache_stride_order
(
)
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_generic_shape
)))
# The allocation respects the backend-defined stride order to ensure
# the semantics remain consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order, which could result in a non-contiguous tensor.
kv_cache_allocation_shape
=
tuple
(
kv_cache_generic_shape
[
i
]
for
i
in
kv_cache_stride_order
)
for
_
in
range
(
self
.
num_attention_layers
):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
device
=
device
)
layer_kv_cache
=
torch
.
zeros
(
kv_cache_allocation_shape
,
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
device
=
device
).
permute
(
*
kv_cache_stride_order
)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment