Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5952d8ab
Unverified
Commit
5952d8ab
authored
Mar 15, 2025
by
Lucas Wilkinson
Committed by
GitHub
Mar 15, 2025
Browse files
[Attention] Get rid of mla cache alignment (#14842)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
a2ae4965
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
83 deletions
+14
-83
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+11
-28
vllm/envs.py
vllm/envs.py
+0
-10
vllm/utils.py
vllm/utils.py
+0
-6
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+3
-39
No files found.
tests/kernels/test_cache.py
View file @
5952d8ab
...
@@ -8,7 +8,6 @@ import torch
...
@@ -8,7 +8,6 @@ import torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
align_to_256bytes
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -450,22 +449,13 @@ def _create_mla_cache(
...
@@ -450,22 +449,13 @@ def _create_mla_cache(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
align_cache
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
cache_dtype
=
torch
.
uint8
if
kv_cache_dtype
==
"fp8"
else
dtype
cache_dtype
=
torch
.
uint8
if
kv_cache_dtype
==
"fp8"
else
dtype
return
torch
.
zeros
(
num_blocks
,
if
align_cache
:
alloc_entry_size
=
align_to_256bytes
(
entry_size
,
cache_dtype
)
alloc_shape
=
(
num_blocks
,
block_size
,
alloc_entry_size
)
cache_full
=
torch
.
zeros
(
alloc_shape
,
dtype
=
cache_dtype
,
device
=
device
)
cache
=
cache_full
[...,
:
entry_size
]
else
:
cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
block_size
,
entry_size
,
entry_size
,
dtype
=
cache_dtype
,
dtype
=
cache_dtype
,
device
=
device
)
device
=
device
)
return
cache
def
_fill_mla_cache
(
cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
):
def
_fill_mla_cache
(
cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
):
...
@@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
...
@@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_concat_and_cache_mla
(
def
test_concat_and_cache_mla
(
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
...
@@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
...
@@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
ref_temp
=
torch
.
zeros
(
*
kv_cache
.
shape
,
dtype
=
dtype
,
device
=
device
)
ref_temp
=
torch
.
zeros
(
*
kv_cache
.
shape
,
dtype
=
dtype
,
device
=
device
)
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
...
@@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
...
@@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_copy_blocks_mla
(
def
test_copy_blocks_mla
(
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -588,7 +575,6 @@ def test_copy_blocks_mla(
...
@@ -588,7 +575,6 @@ def test_copy_blocks_mla(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -598,7 +584,7 @@ def test_copy_blocks_mla(
...
@@ -598,7 +584,7 @@ def test_copy_blocks_mla(
kv_caches
=
[]
kv_caches
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
kv_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
_fill_mla_cache
(
kv_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
kv_caches
.
append
(
kv_cache
)
kv_caches
.
append
(
kv_cache
)
...
@@ -642,7 +628,6 @@ def test_copy_blocks_mla(
...
@@ -642,7 +628,6 @@ def test_copy_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_swap_blocks_mla
(
def
test_swap_blocks_mla
(
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -653,7 +638,6 @@ def test_swap_blocks_mla(
...
@@ -653,7 +638,6 @@ def test_swap_blocks_mla(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -661,9 +645,9 @@ def test_swap_blocks_mla(
...
@@ -661,9 +645,9 @@ def test_swap_blocks_mla(
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
dst_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
dst_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
dst_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
dst_cache
,
kv_cache_dtype
)
...
@@ -704,15 +688,14 @@ def test_swap_blocks_mla(
...
@@ -704,15 +688,14 @@ def test_swap_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
])
# You can also test "fp8" if needed.
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
align_cache
,
device
):
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
seq_len_tensor
=
torch
.
randint
(
0
,
seq_len_tensor
=
torch
.
randint
(
0
,
...
...
vllm/envs.py
View file @
5952d8ab
...
@@ -84,7 +84,6 @@ if TYPE_CHECKING:
...
@@ -84,7 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE
:
bool
=
False
VLLM_SERVER_DEV_MODE
:
bool
=
False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
int
=
128
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
int
=
128
VLLM_MLA_DISABLE
:
bool
=
False
VLLM_MLA_DISABLE
:
bool
=
False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
...
@@ -580,15 +579,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -580,15 +579,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_RAY_BUNDLE_INDICES"
:
"VLLM_RAY_BUNDLE_INDICES"
:
lambda
:
os
.
getenv
(
"VLLM_RAY_BUNDLE_INDICES"
,
""
),
lambda
:
os
.
getenv
(
"VLLM_RAY_BUNDLE_INDICES"
,
""
),
# When on a Nvidia GPU aligns single entries (within a page) so they are 256
# byte aligned for better performance, this increases the memory usage of
# the cache. Currently this only affects MLA that results in non-256
# byte aligned entries. This matches the alignment the CUDA runtime uses
# for all allocations. Currently this primarily affects MLA, for most other
# models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CUDA_MEM_ALIGN_KV_CACHE"
,
"1"
))),
# In some system, find_loaded_library() may not work. So we allow users to
# In some system, find_loaded_library() may not work. So we allow users to
# specify the path through environment variable VLLM_CUDART_SO_PATH.
# specify the path through environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH"
:
"VLLM_CUDART_SO_PATH"
:
...
...
vllm/utils.py
View file @
5952d8ab
...
@@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int:
...
@@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
def
align_to_256bytes
(
extent
:
int
,
dtype
:
torch
.
dtype
)
->
int
:
dtype_size
=
get_dtype_size
(
dtype
)
eles_per_256bytes
=
256
//
dtype_size
return
round_up
(
extent
,
eles_per_256bytes
)
# `collections` helpers
# `collections` helpers
def
is_list_of
(
def
is_list_of
(
value
:
object
,
value
:
object
,
...
...
vllm/worker/cache_engine.py
View file @
5952d8ab
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""CacheEngine class for managing the KV cache."""
"""CacheEngine class for managing the KV cache."""
from
math
import
prod
from
typing
import
List
from
typing
import
List
import
torch
import
torch
from
vllm
import
envs
from
vllm.attention
import
get_attn_backend
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
from
vllm.config
import
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
align_to_256bytes
,
get_dtype_size
,
get_dtype_size
,
is_pin_memory_available
)
is_pin_memory_available
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -42,7 +38,6 @@ class CacheEngine:
...
@@ -42,7 +38,6 @@ class CacheEngine:
self
.
num_attention_layers
=
model_config
.
get_num_layers_by_block_type
(
self
.
num_attention_layers
=
model_config
.
get_num_layers_by_block_type
(
parallel_config
,
LayerBlockType
.
attention
)
parallel_config
,
LayerBlockType
.
attention
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
align_cache
=
self
.
_align_cache
(
model_config
)
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
cache_config
.
block_size
self
.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
self
.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
...
@@ -81,38 +76,18 @@ class CacheEngine:
...
@@ -81,38 +76,18 @@ class CacheEngine:
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
# Align entries so they are 256 byte aligned for better performance
# Primarily targets MLA as this typically only ends up having entries
# be 128 byte aligned.
if
self
.
align_cache
:
# We assume the cache shape is:
# (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
# NOTE this assumption currently only holds for MLA so we only apply
# this optimization when `use_mla` is true
entry_shape
=
kv_cache_shape
[
2
:]
entry_size
=
prod
(
entry_shape
)
alloc_entry_size
=
align_to_256bytes
(
entry_size
,
self
.
dtype
)
alloc_shape
=
(
*
kv_cache_shape
[:
2
],
alloc_entry_size
)
else
:
alloc_shape
=
kv_cache_shape
for
_
in
range
(
self
.
num_attention_layers
):
for
_
in
range
(
self
.
num_attention_layers
):
# null block in CpuGpuBlockAllocator requires at least that
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# block to be zeroed-out.
# We zero-out everything for simplicity.
# We zero-out everything for simplicity.
layer_kv_cache
=
torch
.
zeros
(
alloc
_shape
,
layer_kv_cache
=
torch
.
zeros
(
kv_cache
_shape
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
device
=
device
)
device
=
device
)
# If we allocated with padding for alignment reasons truncate the
# shape while preserving the aligned stride
if
self
.
align_cache
:
layer_kv_cache
=
layer_kv_cache
[...,
:
entry_size
]
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
# when entry_shape is higher than 1D
kv_cache
.
append
(
layer_kv_cache
.
view
(
kv_cache_shape
)
)
kv_cache
.
append
(
layer_kv_cache
)
return
kv_cache
return
kv_cache
def
swap_in
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
def
swap_in
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
...
@@ -128,14 +103,6 @@ class CacheEngine:
...
@@ -128,14 +103,6 @@ class CacheEngine:
def
copy
(
self
,
src_to_dsts
:
torch
.
Tensor
)
->
None
:
def
copy
(
self
,
src_to_dsts
:
torch
.
Tensor
)
->
None
:
self
.
attn_backend
.
copy_blocks
(
self
.
gpu_cache
,
src_to_dsts
)
self
.
attn_backend
.
copy_blocks
(
self
.
gpu_cache
,
src_to_dsts
)
@
staticmethod
def
_align_cache
(
model_config
:
ModelConfig
):
# Currently align_cache only applies to MLA models since the other
# cache kernels haven't been updated yet to support non-continguous
# tensors
return
model_config
.
use_mla
and
current_platform
.
is_cuda
()
\
and
envs
.
VLLM_CUDA_MEM_ALIGN_KV_CACHE
@
staticmethod
@
staticmethod
def
get_cache_block_size
(
def
get_cache_block_size
(
cache_config
:
CacheConfig
,
cache_config
:
CacheConfig
,
...
@@ -153,9 +120,6 @@ class CacheEngine:
...
@@ -153,9 +120,6 @@ class CacheEngine:
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
key_cache_entry
=
num_heads
*
head_size
key_cache_entry
=
num_heads
*
head_size
if
CacheEngine
.
_align_cache
(
model_config
):
key_cache_entry
=
align_to_256bytes
(
key_cache_entry
,
model_config
.
dtype
)
# For MLA there is no value cache, since the latent vector
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
# is joint keys and values.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment