kv_cache_interface.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import copy
7
from dataclasses import dataclass, fields, replace
8
from enum import IntEnum
Chen Zhang's avatar
Chen Zhang committed
9
from math import prod
10
from typing import TYPE_CHECKING
11
12

import torch
13
from typing_extensions import Self
14
15

from vllm.logger import init_logger
16
17
18

if TYPE_CHECKING:
    from vllm.config import VllmConfig
19
from vllm.utils.math_utils import cdiv
20
from vllm.utils.torch_utils import get_dtype_size
21
22
23
24

logger = init_logger(__name__)


25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# ---------------------------------------------------------------------------
# KV cache quantization mode
# ---------------------------------------------------------------------------


class KVQuantMode(IntEnum):
    """Enumerates how the KV cache is quantized.

    Attention backends and kernels branch on this value instead of parsing
    the ``kv_cache_dtype`` string themselves.
    """

    # No quantization.
    NONE = 0
    # A single scale per tensor (the existing fp8 path).
    FP8_PER_TENSOR = 1
    # Dynamic scales computed per (token, head) for int8 storage.
    INT8_PER_TOKEN_HEAD = 2
    # Dynamic scales computed per (token, head) for fp8 storage.
    FP8_PER_TOKEN_HEAD = 3

    @property
    def is_per_token_head(self) -> bool:
        """Whether this mode uses per-token-head scales.

        Relies on every per-token-head member being numbered at or above
        ``INT8_PER_TOKEN_HEAD``.
        """
        return self >= KVQuantMode.INT8_PER_TOKEN_HEAD


def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    """Translate a ``kv_cache_dtype`` string into its :class:`KVQuantMode`.

    The two per-token-head dtypes must match exactly. Any other string that
    begins with ``"fp8"`` is treated as per-tensor fp8, and everything else
    is unquantized.
    """
    per_token_head_modes = {
        "int8_per_token_head": KVQuantMode.INT8_PER_TOKEN_HEAD,
        "fp8_per_token_head": KVQuantMode.FP8_PER_TOKEN_HEAD,
    }
    if kv_cache_dtype in per_token_head_modes:
        return per_token_head_modes[kv_cache_dtype]
    if kv_cache_dtype.startswith("fp8"):
        return KVQuantMode.FP8_PER_TENSOR
    return KVQuantMode.NONE


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True when *kv_cache_dtype* maps to any quantized KV cache mode."""
    mode = get_kv_quant_mode(kv_cache_dtype)
    return mode is not KVQuantMode.NONE


def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Report whether *kv_cache_dtype* requires per-token-head scale storage."""
    mode = get_kv_quant_mode(kv_cache_dtype)
    return mode.is_per_token_head


68
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.

    Subclasses override :attr:`page_size_bytes` and
    :meth:`max_memory_usage_bytes`.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size

        Raises:
            NotImplementedError: always in the base class.
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The engine configuration the bound is computed for.

        Returns:
            The KV cache size in bytes

        Raises:
            NotImplementedError: always in the base class.
        """
        raise NotImplementedError

    def copy_with_new_block_size(self, block_size: int) -> Self:
        """
        Create a new KVCacheSpec from self but replacing the block size.

        All other fields are carried over unchanged via
        ``dataclasses.replace``.
        """
        return replace(self, block_size=block_size)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        The base implementation requires every spec to compare equal to the
        first one and returns a deep copy of that first spec.

        Raises:
            AssertionError: if the specs are not all equal.
        """
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(specs[0])

112

113
@dataclass(frozen=True, kw_only=True)
class AttentionSpec(KVCacheSpec):
    """
    Base KV cache spec for attention layers, sized by KV heads, head size and
    dtype.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    # Quantization mode; per-token-head modes reserve extra scale space in
    # ``page_size_bytes``.
    kv_quant_mode: KVQuantMode = KVQuantMode.NONE
    # Optional page size override; must be >= the computed page size.
    page_size_padded: int | None = None

    @property
    def page_size_bytes(self) -> int:
        """
        Bytes needed for one page of ``block_size`` tokens.

        Starts from ``real_page_size_bytes`` and, for per-token-head
        quantization, adds ``2 * block_size * num_kv_heads`` float32 scales
        (the factor 2 presumably covers K and V). When ``page_size_padded``
        is set, it is returned instead after checking it is large enough.
        """
        real_page_size = self.real_page_size_bytes
        # Per-token-head scales are stored in separate tensors managed
        # by the attention backend, but the memory is carved from the
        # raw KV cache allocation so it must be budgeted here.
        if self.kv_quant_mode.is_per_token_head:
            real_page_size += (
                2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
            )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= real_page_size
            return self.page_size_padded
        return real_page_size

    @property
    def real_page_size_bytes(self) -> int:
        """
        Bytes of K/V storage for one page, without scales or padding.

        The leading factor 2 presumably accounts for K and V.
        """
        return (
            2
            * self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )
145

146

147
@dataclass(frozen=True, kw_only=True)
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for full attention layers.

    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, the sliding window
    layers are treated as full attention by the KV cache manager (blocks are
    allocated for all tokens) while still being computed as sliding window
    attention in the model runner. In this case, we use FullAttentionSpec and
    record the sliding window size.
    """

    # Head size of V; defaults to ``head_size`` (filled in ``__post_init__``).
    head_size_v: int = None  # type: ignore[assignment]

    # Sliding window size; None when not using sliding window attention.
    sliding_window: int | None = None
    # Chunk size for chunked local attention; None when not used.
    attention_chunk_size: int | None = None

    def __post_init__(self):
        # The dataclass is frozen, so bypass its __setattr__ to fill the default.
        if self.head_size_v is None:
            object.__setattr__(self, "head_size_v", self.head_size)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Enough pages for ``max_model_len`` tokens, divided (rounding up)
        across decode and prefill context-parallel ranks when their combined
        world size exceeds 1.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        # Note(hc): each context-parallel rank only needs to store
        # cdiv(max_model_len, dcp_world_size * pcp_world_size) tokens locally.
        if dcp_world_size * pcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Collapse a set of window sizes into a single value.

        Returns None for an empty set and the sole element for a singleton.
        The caller's set is left unmodified.

        Raises:
            ValueError: if the set holds more than one distinct size.
        """
        if len(window_sizes) == 0:
            return None
        elif len(window_sizes) == 1:
            # Read the element without pop() so the input set is not mutated.
            return next(iter(window_sizes))
        else:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size."
            )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        All specs must be FullAttentionSpec and not MLAAttentionSpec, which
        has its own merge. They must also agree on every AttentionSpec field.
        Sliding window and attention chunk sizes are merged across the specs
        that set them, and at most one of the two may end up set.

        Raises:
            AssertionError: if the specs are incompatible.
            ValueError: if specs disagree on a window or chunk size.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        sliding_window = set(
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        )
        attention_chunk_size = set(
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            dtype=specs[0].dtype,
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

    @property
    def real_page_size_bytes(self) -> int:
        """
        Bytes for one page: ``head_size + head_size_v`` elements per token and
        KV head, without scales or padding.
        """
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

247

248
@dataclass(frozen=True, kw_only=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    KV cache spec for MLA attention layers.

    Unlike the parent spec, a page is sized as ``head_size`` elements per
    token and KV head, with no separate ``head_size_v`` term.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    # Cache dtype string (presumably the configured kv_cache_dtype); the value
    # "fp8_ds_mla" switches real_page_size_bytes to a fixed per-token size.
    cache_dtype_str: str | None = None

    @property
    def real_page_size_bytes(self) -> int:
        """
        Bytes for one page, without padding.

        For ``cache_dtype_str == "fp8_ds_mla"``, this is a fixed 656 bytes per
        token, independent of heads, head size and dtype.
        """
        if self.cache_dtype_str == "fp8_ds_mla":
            # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
            #  for details.
            return self.block_size * 656
        return (
            self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge MLAAttentionSpec objects into a single MLAAttentionSpec.

        All specs must be MLAAttentionSpec with the same ``cache_dtype_str``.
        The remaining fields are taken from ``specs[0]``. Sliding window and
        chunk size keep their None defaults.

        Raises:
            AssertionError: if the specs are not all MLAAttentionSpec or use
                different ``cache_dtype_str`` values.
        """
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            kv_quant_mode=specs[0].kv_quant_mode,
            page_size_padded=specs[0].page_size_padded,
            cache_dtype_str=cache_dtype_str_set.pop(),
        )


287
@dataclass(frozen=True, kw_only=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """
    KV cache spec for chunked local attention layers.
    """

    # Chunk size used by chunked local attention, in tokens.
    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Covers ``attention_chunk_size + max_num_batched_tokens`` tokens,
        capped at ``max_model_len``.
        """
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.attention_chunk_size + max_num_batched_tokens, max_model_len
        )

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

305

306
@dataclass(frozen=True, kw_only=True)
class SlidingWindowSpec(AttentionSpec):
    """
    KV cache spec for sliding window attention layers.
    """

    # Sliding window size, in tokens.
    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Raises:
            AssertionError: if decode context parallelism is enabled.
        """
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
        )

        # +1 here because the window may not start at a block boundary, so it
        # can straddle one more block than cdiv() alone accounts for. For
        # example, with block_size=4 and a 4-token window [CDEF] over the
        # sequence ABCDEF, the window spans two blocks, [XXCD] and [EF]
        # (X marks slots whose tokens have left the window).
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
330
331


332
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    KV cache spec for Mamba layers.

    The page size is the sum, over paired ``shapes``/``dtypes`` entries, of
    each shape's element count times its dtype size.
    """

    shapes: tuple[tuple[int, ...], ...]
    # One dtype per entry of ``shapes``; the two are paired with zip().
    dtypes: tuple[torch.dtype, ...]
    # Optional page size override; must be >= the computed page size.
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    mamba_cache_mode: str = "none"
    # Extra pages added to the reservation in max_memory_usage_bytes for the
    # "align" and default cache modes.
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """Bytes per page, or ``page_size_padded`` when that is set."""
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Depends on ``vllm_config.cache_config.mamba_cache_mode``:

        - ``"all"`` reserves one page per block of ``max_model_len`` tokens.
        - ``"align"`` reserves ``2 + num_speculative_blocks`` pages.
        - Any other mode reserves ``1 + num_speculative_blocks`` pages.
        """
        # NOTE(review): the mode is read from cache_config, not from
        # self.mamba_cache_mode. Confirm the two are always kept in sync.
        if vllm_config.cache_config.mamba_cache_mode == "all":
            max_model_len = vllm_config.model_config.max_model_len
            return cdiv(max_model_len, self.block_size) * self.page_size_bytes
        elif vllm_config.cache_config.mamba_cache_mode == "align":
            return self.page_size_bytes * (2 + self.num_speculative_blocks)
        else:
            return self.page_size_bytes * (1 + self.num_speculative_blocks)
360
361


362
363
364
365
366
367
368
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """KV cache spec for encoder-only attention layers, which cache nothing."""

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Encoder-only layers keep no KV cache, so they never need memory
        # regardless of the configuration.
        no_memory_needed = 0
        return no_memory_needed


369
370
371
372
373
374
375
376
377
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Sized from ``scheduler_config.max_num_encoder_input_tokens`` rather
        than from the decoder's ``max_model_len``.
        """
        # For cross-attention, we need to cache encoder states
        # Get encoder length (e.g., 1500 for Whisper).
        max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes


382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """
    FullAttentionSpec variant that additionally records ``sink_len``.
    """

    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        Runs the same validation and field merging as
        ``FullAttentionSpec.merge``. It additionally requires every spec to
        carry the same ``sink_len`` (a spec without the attribute counts as
        None), and that value is copied onto the merged spec.

        Raises:
            AssertionError: if the specs are incompatible, including
                mismatched ``sink_len`` values.
            ValueError: if specs disagree on a window or chunk size.
        """
        # Reuse the parent's checks; cls is SinkFullAttentionSpec here, so the
        # result is already of this type (with sink_len left at its default).
        merged_spec = super().merge(specs)
        sink_lens = {getattr(spec, "sink_len", None) for spec in specs}
        assert len(sink_lens) == 1, (
            "All attention layers in the same KV cache group must have the "
            "same sink length."
        )
        return replace(merged_spec, sink_len=next(iter(sink_lens)))


434
435
436
437
438
439
440
441
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    "same type" means the layers always need the same number of token slots.
    For example, sliding window attentions with different window sizes are not
    the same type and should not be merged into one UniformTypeKVCacheSpecs.
    """

    # Per-layer specs (presumably keyed by layer name; verify against callers).
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """Sum of the page sizes of all contained layer specs."""
        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Each layer needs ``cdiv(its max bytes, its page size)`` pages. The
        largest such count is multiplied by the combined page size.
        """
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
            for spec in self.kv_cache_specs.values()
        )
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.

        Requires a single shared block size. The layers must then all be one
        of the following:

        - FullAttentionSpec (including subclasses).
        - CrossAttentionSpec.
        - SlidingWindowSpec with equal windows.
        - ChunkedLocalAttentionSpec with equal chunk sizes.
        - MambaSpec with equal ``num_speculative_blocks``.

        ``kv_cache_specs`` is assumed to be non-empty.

        Raises:
            NotImplementedError: if the first spec's type is not handled.
        """
        block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = next(iter(kv_cache_specs.values()))
        if isinstance(one_spec, FullAttentionSpec):
            return all(
                isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(
                isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values()
            )
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}"
            )

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not.
        """
        if cls.is_uniform_type(kv_cache_specs):
            block_size = next(iter(kv_cache_specs.values())).block_size
            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
        else:
            return None


511
512
513
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.

    Each instance describes one KV cache tensor and the layers that share it.
    """

    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor
519
520


521
522
523
524
525
526
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """

    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec


534
535
536
537
538
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """

    # The number of KV cache blocks.
    num_blocks: int
    # How the model runner should initialize the KV cache tensors for each layer.
    kv_cache_tensors: list[KVCacheTensor]
    # The KV cache groups of the model.
    # For models with only one type of attention, there is only one group that
    # contains all layers.
    # For models with multiple types of attention, there will be multiple
    # groups; see `_get_kv_cache_config_uniform_page_size` for more details.
    kv_cache_groups: list[KVCacheGroupSpec]

    @property
    def has_mamba_layers(self) -> bool:
        """True if any KV cache group's spec is a :class:`MambaSpec`."""
        return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)

    @property
    def needs_kv_cache_zeroing(self) -> bool:
        """Whether the KV cache must be zeroed; true exactly when Mamba layers exist."""
        return self.has_mamba_layers