kv_cache_interface.py 28.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import copy
7
from collections import Counter
8
from dataclasses import dataclass, fields, replace
9
from enum import IntEnum
Chen Zhang's avatar
Chen Zhang committed
10
from math import prod
11
from typing import TYPE_CHECKING
12
13

import torch
14
from typing_extensions import Self
15
16

from vllm.logger import init_logger
17
18
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
19
20
21

if TYPE_CHECKING:
    from vllm.config import VllmConfig
22
23
24
25

logger = init_logger(__name__)


26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# ---------------------------------------------------------------------------
# KV cache quantization mode
# ---------------------------------------------------------------------------


class KVQuantMode(IntEnum):
    """KV cache quantization mode.

    Used by attention backends and kernels to dispatch quantization logic
    without string matching on ``kv_cache_dtype``.
    """

    NONE = 0
    FP8_PER_TENSOR = 1  # per-tensor scales (current fp8 path)
    INT8_PER_TOKEN_HEAD = 2  # per-token-head dynamic scales for int8
    FP8_PER_TOKEN_HEAD = 3  # per-token-head dynamic scales for fp8
    NVFP4 = 4  # packed fp4 data + fp8 block scales

    @property
    def is_per_token_head(self) -> bool:
        """True for any per-token-head quantization mode."""
        return self in (
            KVQuantMode.INT8_PER_TOKEN_HEAD,
            KVQuantMode.FP8_PER_TOKEN_HEAD,
        )

    @property
    def is_nvfp4(self) -> bool:
        """True for NVFP4 packed quantization mode."""
        return self == KVQuantMode.NVFP4
56
57
58
59
60
61
62
63


def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    """Map a ``kv_cache_dtype`` string to a :class:`KVQuantMode`.

    Exact names are matched first. Any other string starting with ``"fp8"``
    maps to ``KVQuantMode.FP8_PER_TENSOR``; every remaining value maps to
    ``KVQuantMode.NONE``.
    """
    # Exact names must be tested before the generic "fp8" prefix because
    # "fp8_per_token_head" also starts with "fp8".
    if kv_cache_dtype == "int8_per_token_head":
        return KVQuantMode.INT8_PER_TOKEN_HEAD
    if kv_cache_dtype == "fp8_per_token_head":
        return KVQuantMode.FP8_PER_TOKEN_HEAD
    if kv_cache_dtype == "nvfp4":
        return KVQuantMode.NVFP4
    if isinstance(kv_cache_dtype, str) and kv_cache_dtype.startswith("fp8"):
        return KVQuantMode.FP8_PER_TENSOR
    return KVQuantMode.NONE


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True if *kv_cache_dtype* maps to any mode other than NONE."""
    quant_mode = get_kv_quant_mode(kv_cache_dtype)
    return quant_mode != KVQuantMode.NONE


def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Tell whether *kv_cache_dtype* resolves to a per-token-head mode."""
    quant_mode = get_kv_quant_mode(kv_cache_dtype)
    return quant_mode.is_per_token_head


80
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    @property
    def storage_block_size(self) -> int:
        """
        The block size used for storage; equal to `block_size` in this base
        class.
        """
        return self.block_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    def copy_with_new_block_size(self, block_size: int) -> Self:
        """
        Create a new KVCacheSpec from self but replacing the block size.
        """
        return replace(self, block_size=block_size)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        All specs must compare equal; a deep copy of the first one is
        returned.
        """
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(specs[0])

128

129
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec for attention layers, described by the number of KV heads,
    the head size, the storage dtype and the quantization mode.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    kv_quant_mode: KVQuantMode = KVQuantMode.NONE
    # When set, overrides the computed page size (must not be smaller).
    page_size_padded: int | None = None

    @property
    def page_size_bytes(self) -> int:
        """
        The page size in bytes, including per-token-head scale storage and
        any explicit padding.
        """
        real_page_size = self.real_page_size_bytes
        # Per-token-head scales are stored in separate tensors managed
        # by the attention backend, but the memory is carved from the
        # raw KV cache allocation so it must be budgeted here.
        if self.kv_quant_mode.is_per_token_head:
            real_page_size += (
                2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
            )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= real_page_size
            return self.page_size_padded
        return real_page_size

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes: K and V (factor 2) for every token,
        KV head and head element.
        """
        return (
            2
            * self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )
161

162

163
@dataclass(frozen=True, kw_only=True)
class FullAttentionSpec(AttentionSpec):
    """
    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    """

    # Value head size; `__post_init__` falls back to `head_size` when unset.
    head_size_v: int = None  # type: ignore[assignment]

    sliding_window: int | None = None
    """
    Default to None for not using sliding window attention.
    """
    attention_chunk_size: int | None = None

    def __post_init__(self):
        # The dataclass is frozen, so the default is filled in through
        # object.__setattr__ rather than a normal assignment.
        if self.head_size_v is None:
            object.__setattr__(self, "head_size_v", self.head_size)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The memory needed to hold `max_model_len` tokens (divided across
        context-parallel ranks), rounded up to whole blocks.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size * pcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Collapse a set of window sizes into a single value.

        Returns:
            None for an empty set, or the only element of a one-element set.

        Raises:
            ValueError: if the set holds more than one distinct size.
        """
        if len(window_sizes) == 0:
            return None
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size."
            )
        # Read the single element without mutating the caller's set.
        (window_size,) = window_sizes
        return window_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        sliding_window = {
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        }
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            dtype=specs[0].dtype,
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every field declared on AttentionSpec must agree across the group.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes, accounting for a value head size
        that may differ from the key head size.
        """
        if self.kv_quant_mode.is_nvfp4:
            # Packed layout per head: fp4 data + fp8 block scales.
            # fp4 data: head_size//2 bytes (2 fp4 values per byte)
            # fp8 block scale: head_size//16 bytes (1 scale per 16 elements)
            last_dim = nvfp4_kv_cache_full_dim(
                self.head_size
            ) + nvfp4_kv_cache_full_dim(self.head_size_v)
            return (
                self.block_size
                * self.num_kv_heads
                * last_dim
                * get_dtype_size(self.dtype)
            )
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

276

277
278
279
280
281
282
283
284
285
def _apply_alignment_padding(spec: MLAAttentionSpec | SlidingWindowMLASpec):
    if spec.alignment is None:
        return
    actual_page_size = spec.real_page_size_bytes
    padded_page_size = round_up(actual_page_size, spec.alignment)
    if padded_page_size != actual_page_size:
        object.__setattr__(spec, "page_size_padded", padded_page_size)


286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec whose page size is derived from TQ slot bytes.

    Python equivalent of the C++ TQ4FullAttentionSpec. When `tq_slot_size`
    is positive, real_page_size_bytes uses it in place of the raw
    head_size * dtype formula.
    """

    # Zero (the default) keeps the parent's page-size formula.
    tq_slot_size: int = 0

    @property
    def real_page_size_bytes(self) -> int:
        if self.tq_slot_size <= 0:
            return super().real_page_size_bytes
        return self.block_size * self.num_kv_heads * self.tq_slot_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge via FullAttentionSpec.merge, then carry over tq_slot_size."""
        merged = super().merge(specs)
        slot_size = specs[0].tq_slot_size
        assert all(s.tq_slot_size == slot_size for s in specs), (
            "All TQ layers in the same KV cache group must use the same tq_slot_size."
        )
        return replace(merged, tq_slot_size=slot_size)


312
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
    """KV cache spec for MLA attention layers."""

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: str | None = None
    # DeepseekV4 only fields. Non-DeepseekV4 MLA models leave these at defaults.
    alignment: int | None = None  # Default to None for no padding.
    compress_ratio: int = 1  # Default to 1 for no compression.
    model_version: str | None = None

    def __post_init__(self):
        super().__post_init__()
        _apply_alignment_padding(self)

    @property
    def storage_block_size(self) -> int:
        """The block size divided by `compress_ratio` (floor division)."""
        return self.block_size // self.compress_ratio

    @property
    def real_page_size_bytes(self) -> int:
        """The unpadded page size in bytes for the configured cache layout."""
        if self.cache_dtype_str == "fp8_ds_mla":
            if self.model_version == "deepseek_v4":
                # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token.
                # head_size stays semantic (512); bytes are determined here.
                return self.storage_block_size * 584
            # V3.2 main MLA: 656-byte custom layout (kv_lora_rank=512 +
            # qk_rope_head_dim=64, head_size=576). See flashmla_sparse.py.
            return self.block_size * 656
        return (
            self.storage_block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of MLAAttentionSpec objects into a single
        MLAAttentionSpec object.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = {spec.cache_dtype_str for spec in specs}
        compress_ratio_set = {spec.compress_ratio for spec in specs}
        model_version_set = {spec.model_version for spec in specs}
        assert (
            len(cache_dtype_str_set) == 1
            and len(compress_ratio_set) == 1
            and len(model_version_set) == 1
        ), (
            "All attention layers in the same KV cache group must use the same "
            "quantization method, compress ratio, and model version."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            cache_dtype_str=cache_dtype_str_set.pop(),
            compress_ratio=compress_ratio_set.pop(),
            model_version=model_version_set.pop(),
        )


375
@dataclass(frozen=True, kw_only=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """KV cache spec for chunked local attention layers."""

    attention_chunk_size: int

    def max_admission_blocks_per_request(
        self, max_num_batched_tokens: int, max_model_len: int
    ) -> int:
        """Per-request admission cap, in blocks.

        Single source of truth for both startup pool sizing
        (`max_memory_usage_bytes`) and the runtime admission gate, so requests
        admitted by startup can also be admitted at runtime.
        """
        # During chunked prefill, we hold KV for at most one chunk window.
        num_tokens = min(
            self.attention_chunk_size + max_num_batched_tokens, max_model_len
        )
        return cdiv(num_tokens, self.block_size)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """The per-request block cap multiplied by the page size."""
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        max_blocks = self.max_admission_blocks_per_request(
            max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len
        )
        return max_blocks * self.page_size_bytes
401

402

403
@dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for sliding window attention layers."""

    sliding_window: int
    # Value head size; `__post_init__` falls back to `head_size` when unset.
    head_size_v: int = None  # type: ignore[assignment]

    def __post_init__(self):
        # The dataclass is frozen, so the default is filled in through
        # object.__setattr__ rather than a normal assignment.
        if self.head_size_v is None:
            object.__setattr__(self, "head_size_v", self.head_size)

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes, accounting for a value head size
        that may differ from the key head size.
        """
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

    def max_admission_blocks_per_request(
        self, max_num_batched_tokens: int, max_model_len: int
    ) -> int:
        """Per-request admission cap, in blocks.

        Single source of truth for both startup pool sizing
        (`max_memory_usage_bytes`) and the runtime admission gate. Per-request
        real-held blocks plateau at this bound because
        `SlidingWindowManager.remove_skipped_blocks` runs from `allocate_slots`
        before each chunk's `get_num_blocks_to_allocate`.
        """
        # During chunked prefill, we hold KV for the last `sliding_window-1`
        # computed tokens plus the newly scheduled tokens, and never more
        # than `max_model_len`.
        num_tokens = min(
            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
        )
        # +1 because the sliding window may not start from the beginning of
        # the block. E.g. with block size 4 and num_tokens 4, two blocks
        # [XXCD][EF] are needed to store the 4-token window [CDEF].
        return cdiv(num_tokens, self.block_size) + 1

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """The per-request block cap multiplied by the page size."""
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        max_blocks = self.max_admission_blocks_per_request(
            max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len
        )
        return max_blocks * self.page_size_bytes
453
454


455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
@dataclass(frozen=True, kw_only=True)
class SlidingWindowMLASpec(SlidingWindowSpec):
    """Sliding window attention with MLA cache format."""

    cache_dtype_str: str | None = None
    # DeepseekV4-only: see MLAAttentionSpec.model_version.
    alignment: int | None = None  # Default to None for no padding.
    compress_ratio: int = 1
    model_version: str | None = None

    def __post_init__(self):
        # Run the parent hook first so `head_size_v` gets its default, the
        # same way MLAAttentionSpec chains to its parent.
        super().__post_init__()
        _apply_alignment_padding(self)

    @property
    def storage_block_size(self) -> int:
        """The block size divided by `compress_ratio` (floor division)."""
        return self.block_size // self.compress_ratio

    @property
    def real_page_size_bytes(self) -> int:
        """The unpadded page size in bytes for the configured model version."""
        if self.model_version == "deepseek_v4":
            # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token.
            return self.storage_block_size * 584
        assert self.model_version is None, (
            f"Unsupported model version: {self.model_version}"
        )
        return (
            self.storage_block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SlidingWindowMLASpec objects into a single
        SlidingWindowMLASpec object.
        """
        assert all(isinstance(spec, SlidingWindowMLASpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "SlidingWindowMLASpec."
        )
        cache_dtype_str_set = {spec.cache_dtype_str for spec in specs}
        compress_ratio_set = {spec.compress_ratio for spec in specs}
        model_version_set = {spec.model_version for spec in specs}
        sliding_window_set = {spec.sliding_window for spec in specs}
        assert (
            len(cache_dtype_str_set) == 1
            and len(compress_ratio_set) == 1
            and len(model_version_set) == 1
            and len(sliding_window_set) == 1
        ), (
            "All attention layers in the same KV cache group must use the same "
            "quantization method, compress ratio, model version and sliding "
            "window size."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            # Carry the quantization mode through so the merged spec budgets
            # the same page size as its inputs.
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            sliding_window=sliding_window_set.pop(),
            cache_dtype_str=cache_dtype_str_set.pop(),
            compress_ratio=compress_ratio_set.pop(),
            model_version=model_version_set.pop(),
        )


520
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """KV cache spec for Mamba layers, described by state shapes and dtypes."""

    shapes: tuple[tuple[int, ...], ...]
    # One dtype per entry in `shapes`; the two are zipped together.
    dtypes: tuple[torch.dtype, ...]
    # When set, overrides the computed page size (must not be smaller).
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    mamba_cache_mode: str = "none"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """
        The page size in bytes: the summed byte size of every
        (shape, dtype) pair, or `page_size_padded` when it is set.
        """
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum memory usage in bytes, selected by
        `vllm_config.cache_config.mamba_cache_mode`.
        """
        if vllm_config.cache_config.mamba_cache_mode == "all":
            max_model_len = vllm_config.model_config.max_model_len
            return cdiv(max_model_len, self.block_size) * self.page_size_bytes
        elif vllm_config.cache_config.mamba_cache_mode == "align":
            return self.page_size_bytes * (2 + self.num_speculative_blocks)
        else:
            return self.page_size_bytes * (1 + self.num_speculative_blocks)
Chen Zhang's avatar
Chen Zhang committed
548
549


550
551
552
553
554
555
556
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """Attention spec for encoder-only layers, which reserve no KV cache."""

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Always zero: encoder-only layers do not need KV cache."""
        no_kv_cache_bytes = 0
        return no_kv_cache_bytes


557
558
559
560
561
562
563
564
565
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The memory needed to cache the maximum number of encoder input
        tokens, rounded up to whole blocks.
        """
        # For cross-attention, we need to cache encoder states
        # Get encoder length (e.g., 1500 for Whisper).
        max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes


570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec that additionally records a sink length."""

    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        Validation and the shared fields come from FullAttentionSpec.merge;
        `sink_len` is then taken from the first spec.
        """
        merged_spec = super().merge(specs)
        return replace(merged_spec, sink_len=specs[0].sink_len)


622
623
624
625
626
627
628
629
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    same types means always need the same number of token slots. For example,
    sliding window attentions with different window sizes are not the same type
    and should not be merged into one UniformTypeKVCacheSpecs.
    """

    # Per-layer specs, keyed by string (presumably the layer name).
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """The sum of the page sizes of all contained specs."""
        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The largest per-spec page count multiplied by the combined page size.
        """
        return self.max_memory_usage_pages(vllm_config) * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.
        """
        block_sizes = {spec.block_size for spec in kv_cache_specs.values()}
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = next(iter(kv_cache_specs.values()))
        # NOTE: Check subclasses before parent classes since isinstance()
        # returns True for subclasses.
        if isinstance(one_spec, SlidingWindowMLASpec):
            # SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec
            # with the same sliding_window size.
            return all(
                isinstance(spec, SlidingWindowMLASpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, FullAttentionSpec):
            return all(
                isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(
                isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values()
            )
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}"
            )

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not.
        """
        if cls.is_uniform_type(kv_cache_specs):
            block_size = next(iter(kv_cache_specs.values())).block_size
            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
        else:
            return None

    # NOTE: below util functions are only used by DeepseekV4 for now.
    def get_page_sizes(self) -> list[int]:
        """The distinct page sizes among the contained specs."""
        return list(set(spec.page_size_bytes for spec in self.kv_cache_specs.values()))

    def get_num_layer_tuples(self) -> int:
        """The number of specs that share the most common page size."""
        return Counter(
            spec.page_size_bytes for spec in self.kv_cache_specs.values()
        ).most_common(1)[0][1]

    def max_memory_usage_pages(self, vllm_config: VllmConfig) -> int:
        """The largest number of pages any single contained spec may need."""
        return max(
            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
            for spec in self.kv_cache_specs.values()
        )

723

724
725
726
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """

    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor
732
733


734
735
736
737
738
739
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """

    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec
    # Whether this group contains EAGLE/MTP draft attention layers.
    is_eagle_group: bool = False
747
748


749
750
751
752
753
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """

    num_blocks: int
    """The number of KV cache blocks"""
    kv_cache_tensors: list[KVCacheTensor]
    """How should model runner initialize the KV cache tensors for each layer"""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The kv cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """

    @property
    def has_mamba_layers(self) -> bool:
        """True if any KV cache group uses a MambaSpec."""
        return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)

    @property
    def needs_kv_cache_zeroing(self) -> bool:
        """True exactly when the config has Mamba layers."""
        return self.has_mamba_layers