kv_cache_interface.py 22.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import copy
7
from dataclasses import dataclass, fields, replace
8
from enum import IntEnum
Chen Zhang's avatar
Chen Zhang committed
9
from math import prod
10
from typing import TYPE_CHECKING
11
12

import torch
13
from typing_extensions import Self
14
15

from vllm.logger import init_logger
16
17
18

if TYPE_CHECKING:
    from vllm.config import VllmConfig
19
from vllm.utils.math_utils import cdiv
20
from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim
21
22
23
24

# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)


25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# ---------------------------------------------------------------------------
# KV cache quantization mode
# ---------------------------------------------------------------------------


class KVQuantMode(IntEnum):
    """KV cache quantization mode.

    Used by attention backends and kernels to dispatch quantization logic
    without string matching on ``kv_cache_dtype``.
    """

    NONE = 0
    FP8_PER_TENSOR = 1  # per-tensor scales (current fp8 path)
    INT8_PER_TOKEN_HEAD = 2  # per-token-head dynamic scales for int8
    FP8_PER_TOKEN_HEAD = 3  # per-token-head dynamic scales for fp8
    NVFP4 = 4  # packed fp4 data + fp8 block scales

    @property
    def is_per_token_head(self) -> bool:
        """True for any per-token-head quantization mode."""
        return self in (
            KVQuantMode.INT8_PER_TOKEN_HEAD,
            KVQuantMode.FP8_PER_TOKEN_HEAD,
        )

    @property
    def is_nvfp4(self) -> bool:
        """True for NVFP4 packed quantization mode."""
        return self == KVQuantMode.NVFP4
55
56
57
58
59
60
61
62


def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    """Map a ``kv_cache_dtype`` string to a :class:`KVQuantMode`.

    Exact names are matched first; any other string starting with ``"fp8"``
    maps to ``FP8_PER_TENSOR``. Everything else, including non-string
    values, maps to ``NONE``.
    """
    if kv_cache_dtype == "int8_per_token_head":
        return KVQuantMode.INT8_PER_TOKEN_HEAD
    if kv_cache_dtype == "fp8_per_token_head":
        return KVQuantMode.FP8_PER_TOKEN_HEAD
    if kv_cache_dtype == "nvfp4":
        return KVQuantMode.NVFP4
    # Must stay after the exact "fp8_per_token_head" match above, which would
    # otherwise be swallowed by this prefix test.
    if isinstance(kv_cache_dtype, str) and kv_cache_dtype.startswith("fp8"):
        return KVQuantMode.FP8_PER_TENSOR
    return KVQuantMode.NONE


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True if *kv_cache_dtype* maps to any mode other than NONE."""
    quant_mode = get_kv_quant_mode(kv_cache_dtype)
    return quant_mode != KVQuantMode.NONE


def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Tell whether *kv_cache_dtype* selects a per-token-head quant mode."""
    quant_mode = get_kv_quant_mode(kv_cache_dtype)
    return quant_mode.is_per_token_head


79
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.

    Subclasses are expected to implement `page_size_bytes` and
    `max_memory_usage_bytes`; the base versions raise NotImplementedError.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    def copy_with_new_block_size(self, block_size: int) -> Self:
        """
        Create a new KVCacheSpec from self but replacing the block size.
        """
        return replace(self, block_size=block_size)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        All specs must compare equal; the result is a deep copy of the first.
        """
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(specs[0])

123

124
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
    """
    Base KV cache spec for attention layers.

    Adds the head layout, element dtype, quantization mode and an optional
    padded page size on top of `KVCacheSpec.block_size`.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    kv_quant_mode: KVQuantMode = KVQuantMode.NONE
    # When set, this value is returned as the page size instead of the
    # computed one; it must not be smaller than the computed size.
    page_size_padded: int | None = None

    @property
    def page_size_bytes(self) -> int:
        """
        The page size in bytes, including per-token-head scale storage and
        honoring `page_size_padded` when it is set.
        """
        real_page_size = self.real_page_size_bytes
        # Per-token-head scales are stored in separate tensors managed
        # by the attention backend, but the memory is carved from the
        # raw KV cache allocation so it must be budgeted here.
        if self.kv_quant_mode.is_per_token_head:
            real_page_size += (
                2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
            )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= real_page_size
            return self.page_size_padded
        return real_page_size

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes, excluding any scale storage:
        2 * block_size * num_kv_heads * head_size * element size.
        """
        return (
            2
            * self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )
156

157

158
@dataclass(frozen=True, kw_only=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for layers whose blocks are allocated for all tokens.

    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    """

    # Head size used for the value part of the page; `__post_init__` sets it
    # to `head_size` when it is left as None.
    head_size_v: int = None  # type: ignore[assignment]

    sliding_window: int | None = None
    """
    Default to None for not using sliding window attention.
    """
    attention_chunk_size: int | None = None

    def __post_init__(self):
        # The dataclass is frozen, so the default has to be written with
        # object.__setattr__ rather than a plain assignment.
        if self.head_size_v is None:
            object.__setattr__(self, "head_size_v", self.head_size)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory needed to hold `max_model_len` tokens (divided across context
        parallel ranks when DCP/PCP is enabled), rounded up to whole pages.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size * pcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Collapse a set of window sizes into a single value.

        Returns:
            None for an empty set, the sole element for a one-element set.

        Raises:
            ValueError: if the set holds more than one distinct size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size."
            )
        # Read the single element without mutating the caller's set.
        return next(iter(window_sizes))

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        The result is built with `cls(...)` from the first spec's fields plus
        the merged sliding window / chunk size, so subclasses calling this
        through `super()` get an instance of their own class.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        sliding_window = {
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        }
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            dtype=specs[0].dtype,
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Only the fields declared on AttentionSpec (and its base) are
        # compared here; fields added by FullAttentionSpec are not.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes, counting `head_size` for keys and
        `head_size_v` for values (or their NVFP4 packed widths).
        """
        if self.kv_quant_mode.is_nvfp4:
            # Packed layout per head: fp4 data + fp8 block scales.
            # fp4 data: head_size//2 bytes (2 fp4 values per byte)
            # fp8 block scale: head_size//16 bytes (1 scale per 16 elements)
            last_dim = nvfp4_kv_cache_full_dim(
                self.head_size
            ) + nvfp4_kv_cache_full_dim(self.head_size_v)
            return (
                self.block_size
                * self.num_kv_heads
                * last_dim
                * get_dtype_size(self.dtype)
            )
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

271

272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec whose page size can be driven by a TQ slot size.

    Python equivalent of the C++ TQ4FullAttentionSpec. When `tq_slot_size`
    is positive, the page size is computed from it instead of from the
    head_size * dtype formula of the parent class.
    """

    tq_slot_size: int = 0

    @property
    def real_page_size_bytes(self) -> int:
        """Unpadded page size in bytes, preferring the TQ slot size if set."""
        # A non-positive slot size means "not configured": use the parent's
        # formula unchanged.
        if self.tq_slot_size <= 0:
            return super().real_page_size_bytes
        slots_per_page = self.block_size * self.num_kv_heads
        return slots_per_page * self.tq_slot_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge like FullAttentionSpec, then carry over the shared slot size."""
        base_spec = super().merge(specs)
        expected_slot_size = specs[0].tq_slot_size
        assert all(s.tq_slot_size == expected_slot_size for s in specs), (
            "All TQ layers in the same KV cache group must use the same tq_slot_size."
        )
        return replace(base_spec, tq_slot_size=expected_slot_size)


298
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    FullAttentionSpec variant for MLA layers, with its own page size formula
    and an optional cache dtype string.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: str | None = None

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes. Unlike the parent class, only
        `head_size` is counted (no separate value head size).
        """
        if self.cache_dtype_str == "fp8_ds_mla":
            # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
            #  for details.
            return self.block_size * 656
        return (
            self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of MLAAttentionSpec objects into a single one.

        All specs must share the same `cache_dtype_str`. The remaining fields
        are taken from the first spec without comparison.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = {spec.cache_dtype_str for spec in specs}
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method."
        )
        # NOTE(review): head_size_v, sliding_window and attention_chunk_size
        # are not forwarded, so the merged spec gets their defaults.
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            cache_dtype_str=cache_dtype_str_set.pop(),
        )


337
@dataclass(frozen=True, kw_only=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """
    KV cache spec for chunked local attention layers.
    """

    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory for `attention_chunk_size + max_num_batched_tokens` tokens,
        capped at `max_model_len` and rounded up to whole pages.
        """
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.attention_chunk_size + max_num_batched_tokens, max_model_len
        )

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

355

356
@dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec):
    """
    KV cache spec for sliding window attention layers.
    """

    sliding_window: int
    # Head size used for the value part of the page; `__post_init__` sets it
    # to `head_size` when it is left as None.
    head_size_v: int = None  # type: ignore[assignment]

    def __post_init__(self):
        # The dataclass is frozen, so the default has to be written with
        # object.__setattr__ rather than a plain assignment.
        if self.head_size_v is None:
            object.__setattr__(self, "head_size_v", self.head_size)

    @property
    def real_page_size_bytes(self) -> int:
        """
        The unpadded page size in bytes, counting `head_size` for keys and
        `head_size_v` for values.
        """
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory for `sliding_window - 1 + max_num_batched_tokens` tokens,
        capped at `max_model_len`, rounded up to whole pages plus one extra
        page. Asserts that decode context parallelism is disabled.
        """
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
        )

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_tokens
        # is 4, the window [CDEF] can straddle two blocks, [XXCD] [EF], so two
        # blocks are needed even though cdiv(4, 4) == 1.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
394
395


396
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec for Mamba layers, described by a set of state shapes and
    their dtypes.
    """

    shapes: tuple[tuple[int, ...], ...]
    # One dtype per entry of `shapes`; the two are zipped pairwise.
    dtypes: tuple[torch.dtype, ...]
    # When set, this value is returned as the page size instead of the
    # computed one; it must not be smaller than the computed size.
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    mamba_cache_mode: str = "none"
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """
        Sum over all (shape, dtype) pairs of element count times element
        size, or `page_size_padded` when that is set.
        """
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Maximum memory in bytes, depending on the mamba cache mode:
        "all" needs one page per block of `max_model_len`, "align" needs
        2 + `num_speculative_blocks` pages, anything else needs
        1 + `num_speculative_blocks` pages.
        """
        # NOTE(review): the mode is read from vllm_config.cache_config, not
        # from self.mamba_cache_mode — confirm the two are kept in sync.
        if vllm_config.cache_config.mamba_cache_mode == "all":
            max_model_len = vllm_config.model_config.max_model_len
            return cdiv(max_model_len, self.block_size) * self.page_size_bytes
        elif vllm_config.cache_config.mamba_cache_mode == "align":
            return self.page_size_bytes * (2 + self.num_speculative_blocks)
        else:
            return self.page_size_bytes * (1 + self.num_speculative_blocks)
Chen Zhang's avatar
Chen Zhang committed
424
425


426
427
428
429
430
431
432
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """Attention spec for encoder-only layers, which report zero KV cache use."""

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return 0: encoder-only layers do not need KV cache."""
        no_kv_cache_bytes = 0
        return no_kv_cache_bytes


433
434
435
436
437
438
439
440
441
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Memory for `max_num_encoder_input_tokens` tokens, rounded up to
        whole pages.
        """
        # For cross-attention, we need to cache encoder states
        # Get encoder length (e.g., 1500 for Whisper).
        max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes


446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """
    FullAttentionSpec that additionally carries a `sink_len` value.
    """

    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        All FullAttentionSpec checks apply; in addition every spec must carry
        the same `sink_len`, which is copied onto the merged spec.
        """
        # FullAttentionSpec.merge constructs its result with `cls(...)`, so
        # this already is a SinkFullAttentionSpec (with the default sink_len).
        merged_spec = super().merge(specs)
        assert all(spec.sink_len == specs[0].sink_len for spec in specs), (
            "All attention layers in the same KV cache group must use the same "
            "sink_len."
        )
        return replace(merged_spec, sink_len=specs[0].sink_len)


498
499
500
501
502
503
504
505
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    same types means always need the same number of token slots. For example,
    sliding window attentions with different window sizes are not the same type
    and should not be merged into one UniformTypeKVCacheSpecs.
    """

    # Per-layer specs, keyed by string (layer name, presumably — confirm
    # against callers).
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """The sum of the page sizes of all contained specs."""
        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The largest page count any contained spec needs, multiplied by the
        combined page size of all specs.
        """
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
            for spec in self.kv_cache_specs.values()
        )
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.

        Raises:
            NotImplementedError: if the first spec is of a type that has no
                branch below.
        """
        block_sizes = {spec.block_size for spec in kv_cache_specs.values()}
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        # The first spec decides which family every other spec is checked
        # against.
        one_spec = next(iter(kv_cache_specs.values()))
        if isinstance(one_spec, FullAttentionSpec):
            return all(
                isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(
                isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values()
            )
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}"
            )

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not.
        """
        if cls.is_uniform_type(kv_cache_specs):
            block_size = next(iter(kv_cache_specs.values())).block_size
            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
        else:
            return None


575
576
577
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """

    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor
583
584


585
586
587
588
589
590
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """

    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec


598
599
600
601
602
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """

    num_blocks: int
    """The number of KV cache blocks"""
    kv_cache_tensors: list[KVCacheTensor]
    """How should model runner initialize the KV cache tensors for each layer"""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The kv cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """

    @property
    def has_mamba_layers(self) -> bool:
        """True if any KV cache group uses a MambaSpec."""
        return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)

    @property
    def needs_kv_cache_zeroing(self) -> bool:
        """True exactly when the model has Mamba layers."""
        return self.has_mamba_layers