Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e96f56e
Unverified
Commit
9e96f56e
authored
Apr 26, 2025
by
Shu Wang
Committed by
GitHub
Apr 25, 2025
Browse files
Allocate kv_cache with stride order (#16605)
Signed-off-by:
shuw
<
shuw@nvidia.com
>
parent
b2789112
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
119 additions
and
50 deletions
+119
-50
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+21
-18
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+41
-19
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+4
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+25
-5
vllm/utils.py
vllm/utils.py
+10
-3
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+18
-5
No files found.
csrc/cache_kernels.cu
View file @
9e96f56e
...
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
...
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t
*
__restrict__
value_cache
,
// [num_blocks, block_size, num_heads,
cache_t
*
__restrict__
value_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
// head_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
const
int
key_stride
,
const
int
value_stride
,
const
int64_t
block_stride
,
const
int64_t
page_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
int64_t
head_stride
,
const
int64_t
key_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
const
int64_t
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
// NOTE: slot_idx can be -1 if the token is padded
// NOTE: slot_idx can be -1 if the token is padded
...
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
...
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const
int
head_idx
=
i
/
head_size
;
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_key_value_idx
=
block_idx
*
block_stride
+
const
int64_t
tgt_key_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_siz
e
+
block_offset
*
page_strid
e
+
head_idx
*
head_s
iz
e
+
head_offset
;
head_idx
*
head_s
trid
e
+
head_offset
;
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
...
@@ -396,16 +397,16 @@ void reshape_and_cache(
...
@@ -396,16 +397,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors.
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)
\
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>
\
<<<grid, block, 0, stream>>>( \
<<<grid, block, 0, stream>>>(
\
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(key.data_ptr()),
\
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()),
\
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),
\
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),
\
slot_mapping.data_ptr<int64_t>(), block_stride,
key
_stride, \
slot_mapping.data_ptr<int64_t>(), block_stride,
page
_stride,
\
value_stride, num_heads, head_size,
block_size,
\
head_stride, key_stride,
value_stride, num_heads, head_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()),
\
block_size,
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
reinterpret_cast<const float*>(v_scale.data_ptr()));
void
reshape_and_cache_flash
(
void
reshape_and_cache_flash
(
...
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
...
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
int
head_size
=
key
.
size
(
2
);
int
head_size
=
key
.
size
(
2
);
int
block_size
=
key_cache
.
size
(
1
);
int
block_size
=
key_cache
.
size
(
1
);
int
key_stride
=
key
.
stride
(
0
);
int64_t
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int64_t
value_stride
=
value
.
stride
(
0
);
int
block_stride
=
key_cache
.
stride
(
0
);
int64_t
block_stride
=
key_cache
.
stride
(
0
);
int64_t
page_stride
=
key_cache
.
stride
(
1
);
int64_t
head_stride
=
key_cache
.
stride
(
2
);
TORCH_CHECK
(
key_cache
.
stride
(
0
)
==
value_cache
.
stride
(
0
));
TORCH_CHECK
(
key_cache
.
stride
(
0
)
==
value_cache
.
stride
(
0
));
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
...
...
tests/kernels/attention/test_cache.py
View file @
9e96f56e
...
@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing
...
@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
HEAD_SIZES
=
[
64
,
80
,
120
,
256
]
HEAD_SIZES
=
[
64
,
80
,
120
,
256
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
CACHE_LAYOUTS
=
[
"NHD"
,
"HND"
]
# Parameters for MLA tests.
# Parameters for MLA tests.
KV_LORA_RANKS
=
[
512
]
KV_LORA_RANKS
=
[
512
]
...
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
...
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_layout"
,
CACHE_LAYOUTS
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_reshape_and_cache_flash
(
def
test_reshape_and_cache_flash
(
kv_cache_factory_flashinfer
,
kv_cache_factory_flashinfer
,
...
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
...
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_cache_layout
:
str
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens
=
num_tokens
//
2
num_blocks
=
num_blocks
//
2
# Create a random slot mapping.
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
num_slots
=
block_size
*
num_blocks
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
qkv
=
torch
.
randn
(
num_tokens
,
qkv
=
torch
.
randn
(
num_tokens
,
3
,
3
,
num_heads
,
num_heads
,
...
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
...
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
kv_cache_dtype
,
kv_cache_dtype
,
dtype
,
dtype
,
device
=
device
,
device
=
device
,
cache_layout
=
kv_cache_layout
,
)
)
key_cache
,
value_cache
=
key_caches
[
0
].
contiguous
(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
),
value_caches
[
0
].
contiguous
()
del
key_caches
del
key_caches
del
value_caches
del
value_caches
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
def
permute_and_compact
(
x
):
y
=
x
if
kv_cache_layout
==
"NHD"
else
x
.
permute
(
0
,
2
,
1
,
3
)
return
y
.
contiguous
()
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache_compact
,
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
(),
dtype
=
torch
.
float16
)
kv_cache_dtype
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache_compact
,
k_scale
.
item
(),
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
(),
kv_cache_dtype
)
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache_compact
,
v_scale
.
item
(),
kv_cache_dtype
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache_compact
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache_compact
.
clone
()
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
...
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
key_cache_compact
=
permute_and_compact
(
key_cache
)
value_cache_compact
=
permute_and_compact
(
value_cache
)
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
key_cache
_compact
,
k_scale
.
item
(),
k_scale
.
item
(),
kv_dtype
=
kv_cache_dtype
)
kv_dtype
=
kv_cache_dtype
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_value_cache
=
torch
.
empty_like
(
value_cache_compact
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
value_cache
_compact
,
v_scale
.
item
(),
v_scale
.
item
(),
kv_dtype
=
kv_cache_dtype
)
kv_dtype
=
kv_cache_dtype
)
...
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
...
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
if
kv_cache_layout
==
"NHD"
:
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
else
:
cloned_key_cache
[
block_idx
,
:,
block_offset
,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
:,
block_offset
,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
torch
.
testing
.
assert_close
(
result_key_cache
,
...
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
...
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
atol
=
0.001
,
atol
=
0.001
,
rtol
=
0.1
)
rtol
=
0.1
)
else
:
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
key_cache
_compact
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_close
(
value_cache
_compact
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
vllm/attention/backends/abstract.py
View file @
9e96f56e
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
swap_blocks
(
def
swap_blocks
(
...
...
vllm/attention/backends/flashinfer.py
View file @
9e96f56e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
import
dataclasses
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
return
stride_order
@
staticmethod
@
staticmethod
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
...
@@ -188,6 +200,7 @@ class FlashInferState(AttentionState):
...
@@ -188,6 +200,7 @@ class FlashInferState(AttentionState):
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
self
.
runner
.
vllm_config
self
.
vllm_config
=
self
.
runner
.
vllm_config
self
.
_kv_cache_layout
=
None
def
_get_workspace_buffer
(
self
):
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
if
self
.
_workspace_buffer
is
None
:
...
@@ -197,10 +210,15 @@ class FlashInferState(AttentionState):
...
@@ -197,10 +210,15 @@ class FlashInferState(AttentionState):
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
return
self
.
_workspace_buffer
def
get_kv_cache_layout
(
self
):
if
self
.
_kv_cache_layout
is
None
:
self
.
_kv_cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
return
self
.
_kv_cache_layout
def
_get_prefill_wrapper
(
self
):
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
self
.
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
def
_get_decode_wrapper
(
self
):
...
@@ -213,7 +231,7 @@ class FlashInferState(AttentionState):
...
@@ -213,7 +231,7 @@ class FlashInferState(AttentionState):
num_qo_heads
//
num_kv_heads
>
4
)
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
self
.
_get_workspace_buffer
(),
"NHD"
,
self
.
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
return
self
.
_decode_wrapper
...
@@ -274,7 +292,8 @@ class FlashInferState(AttentionState):
...
@@ -274,7 +292,8 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_wrapper
=
\
self
.
_graph_decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
self
.
get_kv_cache_layout
(),
use_tensor_cores
)
use_tensor_cores
)
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# We will use flash attention for prefill
# when kv_cache is not provided.
# when kv_cache is not provided.
...
@@ -1036,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1036,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
query
,
query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
...
@@ -1051,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1051,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_query
,
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
...
...
vllm/utils.py
View file @
9e96f56e
...
@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
...
@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
Optional
[
int
]
=
None
,
seed
:
Optional
[
int
]
=
None
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
cache_layout
:
Optional
[
str
]
=
"NHD"
,
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
torch
.
Tensor
]]:
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
torch
.
Tensor
]]:
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
generic_kv_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
assert
cache_layout
in
(
"NHD"
,
"HND"
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
kv_cache_allocation_shape
=
tuple
(
generic_kv_cache_shape
[
i
]
for
i
in
stride_order
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
key_caches
:
list
[
torch
.
Tensor
]
=
[]
key_caches
:
list
[
torch
.
Tensor
]
=
[]
value_caches
:
list
[
torch
.
Tensor
]
=
[]
value_caches
:
list
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
size
=
k
ey_value_cache
_shape
,
key_value_cache
=
torch
.
empty
(
size
=
k
v_cache_allocation
_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
device
)
device
=
device
)
.
permute
(
*
stride_order
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_value_cache
.
uniform_
(
-
scale
,
scale
)
key_value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
elif
cache_dtype
==
'fp8'
:
...
...
vllm/worker/cache_engine.py
View file @
9e96f56e
...
@@ -71,19 +71,32 @@ class CacheEngine:
...
@@ -71,19 +71,32 @@ class CacheEngine:
device
:
str
,
device
:
str
,
)
->
List
[
torch
.
Tensor
]:
)
->
List
[
torch
.
Tensor
]:
"""Allocates KV cache on the specified device."""
"""Allocates KV cache on the specified device."""
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
kv_cache_
generic_
shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
self
.
block_size
,
self
.
num_kv_heads
,
self
.
head_size
)
num_blocks
,
self
.
block_size
,
self
.
num_kv_heads
,
self
.
head_size
)
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
try
:
kv_cache_stride_order
=
self
.
attn_backend
.
get_kv_cache_stride_order
(
)
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_generic_shape
)))
# The allocation respects the backend-defined stride order to ensure
# the semantics remain consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order, which could result in a non-contiguous tensor.
kv_cache_allocation_shape
=
tuple
(
kv_cache_generic_shape
[
i
]
for
i
in
kv_cache_stride_order
)
for
_
in
range
(
self
.
num_attention_layers
):
for
_
in
range
(
self
.
num_attention_layers
):
# null block in CpuGpuBlockAllocator requires at least that
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# block to be zeroed-out.
# We zero-out everything for simplicity.
# We zero-out everything for simplicity.
layer_kv_cache
=
torch
.
zeros
(
kv_cache_shape
,
layer_kv_cache
=
torch
.
zeros
(
dtype
=
self
.
dtype
,
kv_cache_allocation_shape
,
pin_memory
=
pin_memory
,
dtype
=
self
.
dtype
,
device
=
device
)
pin_memory
=
pin_memory
,
device
=
device
).
permute
(
*
kv_cache_stride_order
)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
# when entry_shape is higher than 1D
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment