Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b26b70be
Unverified
Commit
b26b70be
authored
Oct 18, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Oct 18, 2025
Browse files
[Misc] Refactor `get_kv_cache_spec` into `AttentionLayerBase` (#26587)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
ab4be40f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
151 additions
and
118 deletions
+151
-118
vllm/attention/layer.py
vllm/attention/layer.py
+53
-5
vllm/attention/layers/chunked_local_attention.py
vllm/attention/layers/chunked_local_attention.py
+13
-0
vllm/attention/layers/cross_attention.py
vllm/attention/layers/cross_attention.py
+9
-1
vllm/attention/layers/encoder_only_attention.py
vllm/attention/layers/encoder_only_attention.py
+6
-0
vllm/model_executor/layers/attention_layer_base.py
vllm/model_executor/layers/attention_layer_base.py
+11
-0
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+29
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+1
-1
vllm/utils/__init__.py
vllm/utils/__init__.py
+9
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+19
-110
No files found.
vllm/attention/layer.py
View file @
b26b70be
...
...
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config.vllm
import
VllmConfig
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
,
...
...
@@ -34,7 +35,16 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.models.vision
import
get_vit_attn_backend
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
(
direct_register_custom_op
,
kv_cache_dtype_str_to_dtype
,
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
MLAAttentionSpec
,
SlidingWindowSpec
,
)
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
logger
=
init_logger
(
__name__
)
...
...
@@ -152,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
else
:
sliding_window
=
None
vllm_config
=
get_current_vllm_config
()
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
...
...
@@ -160,6 +171,9 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype
=
"auto"
block_size
=
16
calculate_kv_scales
=
False
self
.
kv_cache_torch_dtype
=
kv_cache_dtype_str_to_dtype
(
kv_cache_dtype
,
vllm_config
.
model_config
)
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
assert
num_heads
%
num_kv_heads
==
0
,
(
...
...
@@ -256,7 +270,7 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
use_direct_call
=
not
current_platform
.
opaque_attention_op
()
self
.
use_output
=
self
.
attn_backend
.
accept_output_buffer
compilation_config
=
get_current_
vllm_config
()
.
compilation_config
compilation_config
=
vllm_config
.
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
...
...
@@ -276,9 +290,7 @@ class Attention(nn.Module, AttentionLayerBase):
# this variable will not be accessed if use_direct_call is True
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
get_current_vllm_config
().
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
vllm_config
.
parallel_config
.
pipeline_parallel_size
)
]
# Initialize q/k/v range constants.
...
...
@@ -394,6 +406,30 @@ class Attention(nn.Module, AttentionLayerBase):
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
# Block size may get updated after model loading, refresh it
block_size
=
vllm_config
.
cache_config
.
block_size
# Should not be called for enc-dec or encoder-only attention.
assert
self
.
attn_type
==
AttentionType
.
DECODER
if
self
.
sliding_window
is
not
None
:
assert
not
vllm_config
.
model_config
.
use_mla
,
(
"MLA is not supported for slidingwindow"
)
return
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
dtype
=
self
.
kv_cache_torch_dtype
,
sliding_window
=
self
.
sliding_window
,
)
else
:
return
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
dtype
=
self
.
kv_cache_torch_dtype
,
)
class
MultiHeadAttention
(
nn
.
Module
):
"""Multi-headed attention without any cache, used for ViT."""
...
...
@@ -749,6 +785,18 @@ class MLAAttention(nn.Module, AttentionLayerBase):
def
get_attn_backend
(
self
)
->
type
[
AttentionBackend
]:
return
self
.
attn_backend
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
kv_cache_dtype
=
kv_cache_dtype_str_to_dtype
(
self
.
kv_cache_dtype
,
vllm_config
.
model_config
)
return
MLAAttentionSpec
(
block_size
=
vllm_config
.
cache_config
.
block_size
,
num_kv_heads
=
1
,
head_size
=
self
.
head_size
,
dtype
=
kv_cache_dtype
,
cache_dtype_str
=
vllm_config
.
cache_config
.
cache_dtype
,
)
def
wait_for_kv_layer_from_connector
(
layer_name
:
str
):
if
not
has_kv_transfer_group
()
or
not
is_v1_kv_transfer_group
():
...
...
vllm/attention/layers/chunked_local_attention.py
View file @
b26b70be
...
...
@@ -9,6 +9,7 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config.vllm
import
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
...
...
@@ -16,6 +17,7 @@ from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches
,
subclass_attention_backend
,
)
from
vllm.v1.kv_cache_interface
import
ChunkedLocalAttentionSpec
,
KVCacheSpec
from
..layer
import
Attention
...
...
@@ -67,6 +69,7 @@ class ChunkedLocalAttention(Attention):
kv_sharing_target_layer_name
:
str
|
None
=
None
,
prefix
:
str
=
""
,
):
self
.
attention_chunk_size
=
attention_chunk_size
dtype
=
torch
.
get_default_dtype
()
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
...
...
@@ -99,3 +102,13 @@ class ChunkedLocalAttention(Attention):
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
,
attn_backend
=
attn_backend
,
)
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
assert
self
.
attention_chunk_size
return
ChunkedLocalAttentionSpec
(
block_size
=
vllm_config
.
cache_config
.
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
dtype
=
self
.
kv_cache_torch_dtype
,
attention_chunk_size
=
self
.
attention_chunk_size
,
)
vllm/attention/layers/cross_attention.py
View file @
b26b70be
...
...
@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata
,
subclass_attention_backend
,
)
from
vllm.v1.kv_cache_interface
import
CrossAttentionSpec
from
vllm.v1.kv_cache_interface
import
CrossAttentionSpec
,
KVCacheSpec
logger
=
init_logger
(
__name__
)
...
...
@@ -174,3 +174,11 @@ class CrossAttention(Attention):
attn_type
=
AttentionType
.
ENCODER_DECODER
,
**
kwargs
,
)
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
return
CrossAttentionSpec
(
block_size
=
vllm_config
.
cache_config
.
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
dtype
=
self
.
kv_cache_torch_dtype
,
)
vllm/attention/layers/encoder_only_attention.py
View file @
b26b70be
...
...
@@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config.vllm
import
VllmConfig
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
subclass_attention_backend
,
)
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
@
functools
.
lru_cache
...
...
@@ -98,3 +100,7 @@ class EncoderOnlyAttention(Attention):
attn_type
=
AttentionType
.
ENCODER_ONLY
,
**
kwargs
,
)
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
# Does not need KV cache
return
None
vllm/model_executor/layers/attention_layer_base.py
View file @
b26b70be
...
...
@@ -5,6 +5,9 @@
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
from
vllm.config
import
VllmConfig
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
...
...
@@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
"""Get the attention backend class for this layer."""
pass
@
abstractmethod
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
"""
Get the KV cache spec for this layer.
May be None if the layer does not need KV cache.
"""
pass
vllm/model_executor/layers/mamba/abstract.py
View file @
b26b70be
...
...
@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING
import
torch
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
,
MambaSpec
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
...
...
@@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
"""Get the attention backend class for this Mamba layer."""
pass
@
abstractmethod
def
get_state_dtype
(
self
)
->
tuple
[
torch
.
dtype
,
...]:
pass
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
if
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
model_config
.
hf_config
.
model_type
not
in
[
"qwen3_next"
]
):
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size
=
vllm_config
.
cache_config
.
mamba_block_size
page_size_padded
=
vllm_config
.
cache_config
.
mamba_page_size_padded
return
MambaSpec
(
shapes
=
self
.
get_state_shape
(),
dtypes
=
self
.
get_state_dtype
(),
block_size
=
mamba_block_size
,
page_size_padded
=
page_size_padded
,
mamba_type
=
self
.
mamba_type
,
num_speculative_blocks
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
if
vllm_config
.
speculative_config
else
0
),
)
vllm/model_executor/models/deepseek_v2.py
View file @
b26b70be
...
...
@@ -481,7 +481,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
return
MLAAttentionSpec
(
# Only has one vector instead of K + V
block_size
=
self
.
cache_config
.
block_size
,
num_kv_heads
=
1
,
...
...
vllm/utils/__init__.py
View file @
b26b70be
...
...
@@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
torch
.
set_num_threads
(
old_num_threads
)
def
kv_cache_dtype_str_to_dtype
(
kv_cache_dtype
:
str
,
model_config
:
ModelConfig
)
->
torch
.
dtype
:
if
kv_cache_dtype
==
"auto"
:
# Model config may not be specified for unit tests, default to float16
return
model_config
.
dtype
if
model_config
else
torch
.
half
return
STR_DTYPE_TO_TORCH_DTYPE
[
kv_cache_dtype
]
T
=
TypeVar
(
"T"
)
U
=
TypeVar
(
"U"
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
b26b70be
...
...
@@ -948,7 +948,7 @@ class EagleProposer:
indexer_layers
[
first_layer
]
.
get_attn_backend
()
.
get_builder_cls
()(
indexer_layers
[
first_layer
].
get_kv_cache_spec
(),
indexer_layers
[
first_layer
].
get_kv_cache_spec
(
self
.
vllm_config
),
self
.
indexer_layer_names
,
self
.
vllm_config
,
self
.
device
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b26b70be
...
...
@@ -19,8 +19,6 @@ from tqdm import tqdm
import
vllm.envs
as
envs
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.attention.backends.abstract
import
AttentionBackend
,
MultipleOf
from
vllm.attention.layer
import
MLAAttention
from
vllm.attention.layers.chunked_local_attention
import
ChunkedLocalAttention
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.cuda_graph
import
CUDAGraphWrapper
from
vllm.compilation.monitor
import
set_cudagraph_capturing_enabled
...
...
@@ -44,10 +42,8 @@ from vllm.distributed.parallel_state import (
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.deepseek_v2
import
DeepseekV32IndexerCache
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
is_mixture_of_experts
,
...
...
@@ -73,11 +69,11 @@ from vllm.sampling_params import SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
GenerationTask
,
PoolingTask
,
SupportedTask
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
kv_cache_dtype_str_to_dtype
,
length_from_prompt_token_ids_or_embeds
,
round_up
,
supports_dynamo
,
...
...
@@ -106,7 +102,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec
,
KVCacheSpec
,
MambaSpec
,
MLAAttentionSpec
,
SlidingWindowSpec
,
UniformTypeKVCacheSpecs
,
)
...
...
@@ -239,10 +234,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
device
=
device
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
if
cache_config
.
cache_dtype
==
"auto"
:
self
.
kv_cache_dtype
=
self
.
dtype
else
:
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
self
.
kv_cache_dtype
=
kv_cache_dtype_str_to_dtype
(
cache_config
.
cache_dtype
,
self
.
model_config
)
self
.
is_pooling_model
=
model_config
.
runner_type
==
"pooling"
self
.
enable_prompt_embeds
=
model_config
.
enable_prompt_embeds
...
...
@@ -4577,109 +4571,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
use_mla
=
self
.
vllm_config
.
model_config
.
use_mla
cache_dtype_str
=
self
.
vllm_config
.
cache_config
.
cache_dtype
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
)
for
layer_name
,
attn_module
in
attn_layers
.
items
():
if
isinstance
(
attn_module
,
Attention
):
if
(
kv_tgt_layer
:
=
attn_module
.
kv_sharing_target_layer_name
)
is
not
None
:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self
.
shared_kv_cache_layers
[
layer_name
]
=
kv_tgt_layer
continue
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
assert
not
use_mla
,
"MLA is not supported for slidingwindow"
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
sliding_window
=
attn_module
.
sliding_window
,
)
elif
self
.
attention_chunk_size
is
not
None
and
isinstance
(
attn_module
,
ChunkedLocalAttention
):
kv_cache_spec
[
layer_name
]
=
ChunkedLocalAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
attention_chunk_size
=
self
.
attention_chunk_size
,
)
else
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
)
elif
attn_module
.
attn_type
==
AttentionType
.
ENCODER_DECODER
:
kv_cache_spec
[
layer_name
]
=
CrossAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
)
elif
attn_module
.
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
,
):
# encoder-only attention does not need KV cache.
continue
else
:
raise
ValueError
(
f
"Unknown attention type:
{
attn_module
.
attn_type
}
"
)
elif
isinstance
(
attn_module
,
MLAAttention
):
kv_cache_spec
[
layer_name
]
=
MLAAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
1
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
cache_dtype_str
=
cache_dtype_str
,
)
elif
isinstance
(
attn_module
,
MambaBase
):
if
(
self
.
vllm_config
.
speculative_config
is
not
None
and
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
not
in
[
"qwen3_next"
]
):
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size
=
self
.
vllm_config
.
cache_config
.
mamba_block_size
page_size_padded
=
self
.
vllm_config
.
cache_config
.
mamba_page_size_padded
kv_cache_spec
[
layer_name
]
=
MambaSpec
(
shapes
=
attn_module
.
get_state_shape
(),
dtypes
=
attn_module
.
get_state_dtype
(),
block_size
=
mamba_block_size
,
page_size_padded
=
page_size_padded
,
mamba_type
=
attn_module
.
mamba_type
,
num_speculative_blocks
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
),
)
ds_indexer_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
DeepseekV32IndexerCache
)
for
layer_name
,
ds_indexer_module
in
ds_indexer_layers
.
items
():
kv_cache_spec
[
layer_name
]
=
ds_indexer_module
.
get_kv_cache_spec
()
if
isinstance
(
attn_module
,
Attention
)
and
(
kv_tgt_layer
:
=
attn_module
.
kv_sharing_target_layer_name
):
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self
.
shared_kv_cache_layers
[
layer_name
]
=
kv_tgt_layer
continue
# Skip modules that don't need KV cache (eg encoder-only attention)
if
spec
:
=
attn_module
.
get_kv_cache_spec
(
self
.
vllm_config
):
kv_cache_spec
[
layer_name
]
=
spec
return
kv_cache_spec
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment