Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5952d8ab
Unverified
Commit
5952d8ab
authored
Mar 15, 2025
by
Lucas Wilkinson
Committed by
GitHub
Mar 15, 2025
Browse files
[Attention] Get rid of mla cache alignment (#14842)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
a2ae4965
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
83 deletions
+14
-83
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+11
-28
vllm/envs.py
vllm/envs.py
+0
-10
vllm/utils.py
vllm/utils.py
+0
-6
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+3
-39
No files found.
tests/kernels/test_cache.py
View file @
5952d8ab
...
...
@@ -8,7 +8,6 @@ import torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
align_to_256bytes
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -450,22 +449,13 @@ def _create_mla_cache(
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
device
:
str
,
align_cache
:
bool
,
)
->
torch
.
Tensor
:
cache_dtype
=
torch
.
uint8
if
kv_cache_dtype
==
"fp8"
else
dtype
if
align_cache
:
alloc_entry_size
=
align_to_256bytes
(
entry_size
,
cache_dtype
)
alloc_shape
=
(
num_blocks
,
block_size
,
alloc_entry_size
)
cache_full
=
torch
.
zeros
(
alloc_shape
,
dtype
=
cache_dtype
,
device
=
device
)
cache
=
cache_full
[...,
:
entry_size
]
else
:
cache
=
torch
.
zeros
(
num_blocks
,
return
torch
.
zeros
(
num_blocks
,
block_size
,
entry_size
,
dtype
=
cache_dtype
,
device
=
device
)
return
cache
def
_fill_mla_cache
(
cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
):
...
...
@@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
])
@
torch
.
inference_mode
()
def
test_concat_and_cache_mla
(
kv_lora_rank
:
int
,
...
...
@@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
ref_temp
=
torch
.
zeros
(
*
kv_cache
.
shape
,
dtype
=
dtype
,
device
=
device
)
for
i
in
range
(
num_tokens
):
...
...
@@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_copy_blocks_mla
(
kv_lora_rank
:
int
,
...
...
@@ -588,7 +575,6 @@ def test_copy_blocks_mla(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -598,7 +584,7 @@ def test_copy_blocks_mla(
kv_caches
=
[]
for
_
in
range
(
num_layers
):
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
kv_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
kv_caches
.
append
(
kv_cache
)
...
...
@@ -642,7 +628,6 @@ def test_copy_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_swap_blocks_mla
(
kv_lora_rank
:
int
,
...
...
@@ -653,7 +638,6 @@ def test_swap_blocks_mla(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -661,9 +645,9 @@ def test_swap_blocks_mla(
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
dst_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
dst_cache
,
kv_cache_dtype
)
...
...
@@ -704,15 +688,14 @@ def test_swap_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
align_cache
,
device
):
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
seq_len_tensor
=
torch
.
randint
(
0
,
...
...
vllm/envs.py
View file @
5952d8ab
...
...
@@ -84,7 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE
:
bool
=
False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
int
=
128
VLLM_MLA_DISABLE
:
bool
=
False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
...
...
@@ -580,15 +579,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_RAY_BUNDLE_INDICES"
:
lambda
:
os
.
getenv
(
"VLLM_RAY_BUNDLE_INDICES"
,
""
),
# When on a Nvidia GPU aligns single entries (within a page) so they are 256
# byte aligned for better performance, this increases the memory usage of
# the cache. Currently this only affects MLA that results in non-256
# byte aligned entries. This matches the alignment the CUDA runtime uses
# for all allocations. Currently this primarily affects MLA, for most other
# models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CUDA_MEM_ALIGN_KV_CACHE"
,
"1"
))),
# In some system, find_loaded_library() may not work. So we allow users to
# specify the path through environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH"
:
...
...
vllm/utils.py
View file @
5952d8ab
...
...
@@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
def
align_to_256bytes
(
extent
:
int
,
dtype
:
torch
.
dtype
)
->
int
:
dtype_size
=
get_dtype_size
(
dtype
)
eles_per_256bytes
=
256
//
dtype_size
return
round_up
(
extent
,
eles_per_256bytes
)
# `collections` helpers
def
is_list_of
(
value
:
object
,
...
...
vllm/worker/cache_engine.py
View file @
5952d8ab
# SPDX-License-Identifier: Apache-2.0
"""CacheEngine class for managing the KV cache."""
from
math
import
prod
from
typing
import
List
import
torch
from
vllm
import
envs
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
align_to_256bytes
,
get_dtype_size
,
is_pin_memory_available
)
get_dtype_size
,
is_pin_memory_available
)
logger
=
init_logger
(
__name__
)
...
...
@@ -42,7 +38,6 @@ class CacheEngine:
self
.
num_attention_layers
=
model_config
.
get_num_layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
align_cache
=
self
.
_align_cache
(
model_config
)
self
.
block_size
=
cache_config
.
block_size
self
.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
...
...
@@ -81,38 +76,18 @@ class CacheEngine:
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
# Align entries so they are 256 byte aligned for better performance
# Primarily targets MLA as this typically only ends up having entries
# be 128 byte aligned.
if
self
.
align_cache
:
# We assume the cache shape is:
# (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
# NOTE this assumption currently only holds for MLA so we only apply
# this optimization when `use_mla` is true
entry_shape
=
kv_cache_shape
[
2
:]
entry_size
=
prod
(
entry_shape
)
alloc_entry_size
=
align_to_256bytes
(
entry_size
,
self
.
dtype
)
alloc_shape
=
(
*
kv_cache_shape
[:
2
],
alloc_entry_size
)
else
:
alloc_shape
=
kv_cache_shape
for
_
in
range
(
self
.
num_attention_layers
):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache
=
torch
.
zeros
(
alloc
_shape
,
layer_kv_cache
=
torch
.
zeros
(
kv_cache
_shape
,
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
device
=
device
)
# If we allocated with padding for alignment reasons truncate the
# shape while preserving the aligned stride
if
self
.
align_cache
:
layer_kv_cache
=
layer_kv_cache
[...,
:
entry_size
]
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
kv_cache
.
append
(
layer_kv_cache
.
view
(
kv_cache_shape
)
)
kv_cache
.
append
(
layer_kv_cache
)
return
kv_cache
def
swap_in
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
...
...
@@ -128,14 +103,6 @@ class CacheEngine:
def
copy
(
self
,
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
copy_blocks
(
self
.
gpu_cache
,
src_to_dsts
)
@
staticmethod
def
_align_cache
(
model_config
:
ModelConfig
):
# Currently align_cache only applies to MLA models since the other
# cache kernels haven't been updated yet to support non-continguous
# tensors
return
model_config
.
use_mla
and
current_platform
.
is_cuda
()
\
and
envs
.
VLLM_CUDA_MEM_ALIGN_KV_CACHE
@
staticmethod
def
get_cache_block_size
(
cache_config
:
CacheConfig
,
...
...
@@ -153,9 +120,6 @@ class CacheEngine:
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
key_cache_entry
=
num_heads
*
head_size
if
CacheEngine
.
_align_cache
(
model_config
):
key_cache_entry
=
align_to_256bytes
(
key_cache_entry
,
model_config
.
dtype
)
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment