Commit b3230e1a authored by Yongye Zhu, committed by simon-mo
Browse files
parent 03df0fb5
...@@ -102,6 +102,7 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -102,6 +102,7 @@ class PallasAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
padded_head_size = cdiv( padded_head_size = cdiv(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
......
...@@ -360,6 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend): ...@@ -360,6 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -68,6 +68,7 @@ class TreeAttentionBackend(AttentionBackend): ...@@ -68,6 +68,7 @@ class TreeAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -171,6 +171,7 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -171,6 +171,7 @@ class TritonAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -106,6 +106,7 @@ class XFormersAttentionBackend(AttentionBackend): ...@@ -106,6 +106,7 @@ class XFormersAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model kv_cache_spec: The kv cache spec of each attention layer in the model
""" """
if is_kv_cache_spec_uniform(kv_cache_spec): if is_kv_cache_spec_uniform(
kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type(
kv_cache_spec):
return return
logger.warning( logger.warning(
...@@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads, num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size, head_size=spec.head_size,
dtype=spec.dtype, dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window, sliding_window=spec.sliding_window,
) )
elif isinstance(spec, ChunkedLocalAttentionSpec): elif isinstance(spec, ChunkedLocalAttentionSpec):
...@@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
num_kv_heads=spec.num_kv_heads, num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size, head_size=spec.head_size,
dtype=spec.dtype, dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size, attention_chunk_size=spec.attention_chunk_size,
) )
if not is_kv_cache_spec_uniform(kv_cache_spec): if not (is_kv_cache_spec_uniform(kv_cache_spec)
or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)):
raise ValueError("Hybrid KV cache manager is disabled but failed to " raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.") "convert the KV cache specs to one unified type.")
......
...@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock ...@@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
CrossAttentionSpec, FullAttentionSpec, CrossAttentionSpec, FullAttentionSpec,
KVCacheSpec, MambaSpec, KVCacheSpec, MambaSpec,
SlidingWindowSpec) MLAAttentionSpec, SlidingWindowSpec)
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -656,6 +656,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): ...@@ -656,6 +656,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager, FullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager, SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager, MambaSpec: MambaManager,
......
...@@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec): ...@@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec):
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: torch.dtype dtype: torch.dtype
use_mla: bool
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector return 2 * self.block_size * self.num_kv_heads * self.head_size \
coef = 1 if self.use_mla else 2
return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype) * get_dtype_size(self.dtype)
...@@ -118,12 +115,13 @@ class FullAttentionSpec(AttentionSpec): ...@@ -118,12 +115,13 @@ class FullAttentionSpec(AttentionSpec):
if spec.sliding_window is not None) if spec.sliding_window is not None)
attention_chunk_size = set(spec.attention_chunk_size for spec in specs attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None) if spec.attention_chunk_size is not None)
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
merged_spec = cls( merged_spec = cls(
block_size=specs[0].block_size, block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads, num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size, head_size=specs[0].head_size,
dtype=specs[0].dtype, dtype=specs[0].dtype,
use_mla=specs[0].use_mla,
sliding_window=cls.merge_window_sizes(sliding_window), sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
) )
...@@ -140,6 +138,38 @@ class FullAttentionSpec(AttentionSpec): ...@@ -140,6 +138,38 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec return merged_spec
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    """KV cache spec for attention layers that use MLA.

    Extends ``FullAttentionSpec`` with the KV cache dtype *string*, which is
    needed because one cache format ("fp8_ds_mla") has a page size that is
    not derived from ``num_kv_heads``, ``head_size`` and ``dtype``.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: Optional[str] = None

    @property
    def page_size_bytes(self) -> int:
        """Return the size in bytes of one KV cache page for this spec.

        No factor of 2 is applied in either branch below.
        """
        if self.cache_dtype_str != "fp8_ds_mla":
            # Generic case: size follows from heads, head size and dtype.
            return self.block_size * self.num_kv_heads * self.head_size \
                * get_dtype_size(self.dtype)
        # "fp8_ds_mla" uses a fixed 656 bytes per block entry.
        # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
        # for details.
        return self.block_size * 656

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge the specs of several MLA layers into a single spec.

        Every element must be an ``MLAAttentionSpec`` and all of them must
        share the same ``cache_dtype_str``; otherwise an ``AssertionError``
        is raised. The remaining fields are taken from the first spec.
        """
        assert not any(not isinstance(spec, MLAAttentionSpec)
                       for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "MLAAttentionSpec.")
        distinct_cache_dtypes = {spec.cache_dtype_str for spec in specs}
        assert len(distinct_cache_dtypes) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method.")
        template = specs[0]
        # Exactly one element is present, as asserted above.
        (shared_cache_dtype, ) = distinct_cache_dtypes
        return cls(
            block_size=template.block_size,
            num_kv_heads=template.num_kv_heads,
            head_size=template.head_size,
            dtype=template.dtype,
            cache_dtype_str=shared_cache_dtype,
        )
@dataclass(frozen=True) @dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec): class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int attention_chunk_size: int
...@@ -163,9 +193,6 @@ class ChunkedLocalAttentionSpec(AttentionSpec): ...@@ -163,9 +193,6 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
class SlidingWindowSpec(AttentionSpec): class SlidingWindowSpec(AttentionSpec):
sliding_window: int sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
"DCP not support sliding window." "DCP not support sliding window."
...@@ -266,9 +293,13 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): ...@@ -266,9 +293,13 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
# Different block sizes, not uniform. # Different block sizes, not uniform.
return False return False
one_spec = next(iter(kv_cache_specs.values())) one_spec = next(iter(kv_cache_specs.values()))
if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)): if isinstance(one_spec, FullAttentionSpec):
return all(
isinstance(spec, FullAttentionSpec)
for spec in kv_cache_specs.values())
elif isinstance(one_spec, CrossAttentionSpec):
return all( return all(
isinstance(spec, type(one_spec)) isinstance(spec, CrossAttentionSpec)
for spec in kv_cache_specs.values()) for spec in kv_cache_specs.values())
elif isinstance(one_spec, SlidingWindowSpec): elif isinstance(one_spec, SlidingWindowSpec):
return all( return all(
......
...@@ -17,6 +17,7 @@ from vllm.forward_context import set_forward_context ...@@ -17,6 +17,7 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -31,6 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata ...@@ -31,6 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -51,6 +53,7 @@ class EagleProposer: ...@@ -51,6 +53,7 @@ class EagleProposer:
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.runner = runner self.runner = runner
self.device = device
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
...@@ -178,20 +181,30 @@ class EagleProposer: ...@@ -178,20 +181,30 @@ class EagleProposer:
assert self.runner is not None assert self.runner is not None
# Select the correct attention metadata builders for EAGLE layers. # FIXME: need to consider multiple kv_cache_groups
# Get the attention metadata builders once and reuse for later. ubatch_id = dbo_current_ubatch_id()
builder = (self._get_attention_metadata_builder() attn_metadata_builder = \
if self.attn_metadata_builder is None else self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
self.attn_metadata_builder) attn_metadata = attn_metadata_builder.build_for_drafting(
attn_metadata = builder.build_for_drafting( # type: ignore common_attn_metadata=common_attn_metadata, draft_index=0)
common_attn_metadata=common_attn_metadata, # FIXME: support hybrid kv for draft model (remove separate indexer)
draft_index=0) if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
))
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV # At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata. # cache group, thus using the same attention metadata.
per_layer_attn_metadata = {} per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
if self.use_cuda_graph and \ if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]: num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
...@@ -323,7 +336,7 @@ class EagleProposer: ...@@ -323,7 +336,7 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID) exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata # Rebuild attention metadata
attn_metadata = builder.build_for_drafting( # type: ignore attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1) draft_index=token_index + 1)
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
...@@ -794,6 +807,10 @@ class EagleProposer: ...@@ -794,6 +807,10 @@ class EagleProposer:
self.vllm_config.speculative_config.draft_model_config self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set( target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()) get_layers_from_vllm_config(self.vllm_config, Attention).keys())
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"): with set_model_tag("eagle_head"):
...@@ -803,8 +820,25 @@ class EagleProposer: ...@@ -803,8 +820,25 @@ class EagleProposer:
draft_attn_layer_names = ( draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() - get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names) target_attn_layer_names)
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache)
draft_indexer_layer_names = (indexer_layers.keys() -
target_indexer_layer_names)
self.attn_layer_names = list(draft_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names:
first_layer = self.indexer_layer_names[0]
self.draft_indexer_metadata_builder = (
indexer_layers[first_layer].get_attn_backend().get_builder_cls(
)(
indexer_layers[first_layer].get_kv_cache_spec(),
self.indexer_layer_names,
self.vllm_config,
self.device,
))
else:
self.draft_indexer_metadata_builder = None
if supports_multimodal(target_model): if supports_multimodal(target_model):
# handle multimodality # handle multimodality
......
...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase ...@@ -40,6 +40,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.model_executor.models.interfaces import (SupportsMultiModal, from vllm.model_executor.models.interfaces import (SupportsMultiModal,
...@@ -80,7 +81,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, ...@@ -80,7 +81,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec, EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec, KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec, MambaSpec, MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs) UniformTypeKVCacheSpecs)
# yapf: enable # yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
...@@ -3068,7 +3070,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3068,7 +3070,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_i = (attn_group\ attn_metadata_i = (attn_group\
.get_metadata_builder(ubatch_id=ubid)\ .get_metadata_builder(ubatch_id=ubid)\
.build_for_cudagraph_capture(common_attn_metadata)) .build_for_cudagraph_capture(common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names: for layer_name in attn_group.layer_names:
assert type(attn_metadata) is list assert type(attn_metadata) is list
attn_metadata[ubid][ attn_metadata[ubid][
layer_name] = attn_metadata_i layer_name] = attn_metadata_i
...@@ -3076,7 +3078,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3076,7 +3078,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert type(attn_metadata) is dict assert type(attn_metadata) is dict
attn_metadata_i = attn_group.get_metadata_builder()\ attn_metadata_i = attn_group.get_metadata_builder()\
.build_for_cudagraph_capture(common_attn_metadata) .build_for_cudagraph_capture(common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names: for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
...@@ -3823,8 +3825,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3823,8 +3825,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
kv_cache_shape = attn_backend.get_kv_cache_shape( kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, num_blocks,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype)
dtype = kv_cache_spec.dtype dtype = kv_cache_spec.dtype
try: try:
kv_cache_stride_order = \ kv_cache_stride_order = \
...@@ -4010,7 +4015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4010,7 +4015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Add encoder-only layers to the KV cache config. Add encoder-only layers to the KV cache config.
""" """
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
encoder_only_attn_specs: dict[AttentionSpec, encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list) list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
...@@ -4020,8 +4024,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4020,8 +4024,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype)
use_mla=use_mla)
encoder_only_attn_specs[attn_spec].append(layer_name) encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name) self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0: if len(encoder_only_attn_specs) > 0:
...@@ -4043,6 +4046,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4043,6 +4046,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla use_mla = self.vllm_config.model_config.use_mla
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items(): for layer_name, attn_module in attn_layers.items():
...@@ -4062,13 +4066,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4062,13 +4066,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the attention backends # the attention backends
if attn_module.attn_type == AttentionType.DECODER: if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None: if attn_module.sliding_window is not None:
assert not use_mla, "MLA is not supported for sliding" \
"window"
kv_cache_spec[layer_name] = SlidingWindowSpec( kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window, sliding_window=attn_module.sliding_window)
use_mla=use_mla) elif use_mla:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str)
elif self.attention_chunk_size is not None \ elif self.attention_chunk_size is not None \
and isinstance(attn_module, ChunkedLocalAttention): and isinstance(attn_module, ChunkedLocalAttention):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
...@@ -4076,22 +4088,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4076,22 +4088,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size, attention_chunk_size=self.attention_chunk_size)
use_mla=use_mla)
else: else:
kv_cache_spec[layer_name] = FullAttentionSpec( kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype)
use_mla=use_mla)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER: elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec( kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype)
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache. # encoder-only attention does not need KV cache.
...@@ -4128,6 +4137,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4128,6 +4137,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.speculative_config.num_speculative_tokens self.speculative_config.num_speculative_tokens
if self.speculative_config else 0), if self.speculative_config else 0),
) )
ds_indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache)
for layer_name, ds_indexer_module in ds_indexer_layers.items():
kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
return kv_cache_spec return kv_cache_spec
......
...@@ -530,7 +530,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -530,7 +530,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window, sliding_window=attn_module.sliding_window,
use_mla=False,
) )
else: else:
kv_cache_spec[layer_name] = FullAttentionSpec( kv_cache_spec[layer_name] = FullAttentionSpec(
...@@ -538,7 +537,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -538,7 +537,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads, num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size, head_size=attn_module.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
use_mla=False,
) )
elif attn_module.attn_type in (AttentionType.ENCODER, elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY): AttentionType.ENCODER_ONLY):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment