"vscode:/vscode.git/clone" did not exist on "a62bc0109c3864b9dc770dc637e3acd332c730ea"
kv_cache_interface.py 16.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import copy
5
from dataclasses import dataclass, fields, replace
Chen Zhang's avatar
Chen Zhang committed
6
from math import prod
7
8

import torch
9
from typing_extensions import Self
10

11
from vllm.config import VllmConfig
12
from vllm.logger import init_logger
13
from vllm.utils.math_utils import cdiv
14
from vllm.utils.torch_utils import get_dtype_size
15
16
17
18

logger = init_logger(__name__)


19
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.

    `page_size_bytes` and `max_memory_usage_bytes` raise
    `NotImplementedError` here and are meant to be overridden by subclasses.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The vLLM configuration the limit is derived from.

        Returns:
            The KV cache size in bytes

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def copy_with_new_block_size(self, block_size: int) -> Self:
        """
        Create a new KVCacheSpec from self but replacing the block size.

        The instance is frozen, so a new object is built with
        `dataclasses.replace` instead of mutating `self`.
        """
        return replace(self, block_size=block_size)

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        Args:
            specs: The specs to merge. Must be non-empty and all equal.

        Returns:
            A deep copy of the first spec.

        Raises:
            AssertionError: If `specs` is empty or the specs are not all
                equal to each other.
        """
        # Fail with a clear message instead of an IndexError on `specs[0]`.
        assert specs, "Cannot merge an empty list of KV cache specs."
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same."
        )
        return copy.deepcopy(specs[0])

63

64
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """
    A KV cache spec described by the number of KV heads, the head size and
    the element dtype, in addition to the inherited `block_size`.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            `2 * block_size * num_kv_heads * head_size * dtype_size`.
        """
        # NOTE(review): the leading factor of 2 presumably accounts for
        # storing both keys and values -- confirm.
        return (
            2
            * self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )
79

80

81
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    """
    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    """

    # Second head size used by `page_size_bytes`; `__post_init__` fills it in
    # with `head_size` when it is left as None.
    head_size_v: int | None = None

    sliding_window: int | None = None
    """
    Default to None for not using sliding window attention.
    """
    attention_chunk_size: int | None = None

    def __post_init__(self):
        """Default `head_size_v` to `head_size` when it was not provided."""
        if self.head_size_v is None:
            # The dataclass is frozen, so plain attribute assignment would
            # raise; bypass it through object.__setattr__.
            object.__setattr__(self, "head_size_v", self.head_size)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `max_model_len` and the decode/prefill
                context parallel sizes.

        Returns:
            The number of blocks needed for the (per-rank) maximum number of
            tokens, multiplied by `page_size_bytes`.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size * pcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """
        Collapse a set of window sizes into a single optional value.

        Args:
            window_sizes: The distinct (non-None) window sizes of a group.

        Returns:
            None if the set is empty, otherwise its only element.

        Raises:
            ValueError: If the set holds more than one distinct size.
        """
        if len(window_sizes) == 0:
            return None
        elif len(window_sizes) == 1:
            # Read the single element without mutating the caller's set.
            return next(iter(window_sizes))
        else:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size."
            )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        The merged spec takes `block_size`, `num_kv_heads`, `head_size`,
        `head_size_v` and `dtype` from the first spec, and the single
        `sliding_window` / `attention_chunk_size` value found among the specs
        (None when no spec sets it).

        Raises:
            AssertionError: If `specs` is empty, contains a non
                FullAttentionSpec or an MLAAttentionSpec, if any spec differs
                from the merged spec in an `AttentionSpec` field, or if both
                a sliding window and an attention chunk size are present.
            ValueError: If the specs carry more than one distinct sliding
                window or attention chunk size.
        """
        # Fail with a clear message instead of an IndexError on `specs[0]`.
        assert specs, "Cannot merge an empty list of KV cache specs."
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        sliding_window = set(
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        )
        attention_chunk_size = set(
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Only the fields declared up to AttentionSpec are compared here.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        # Booleans add up as 0/1: at most one of the two may be set.
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            `block_size * num_kv_heads * (head_size + head_size_v) *
            dtype_size`.
        """
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )

179

180
181
182
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
    """
    A FullAttentionSpec variant with its own page size formula and an
    optional cache dtype string.
    """

    # TODO(Lucas/Chen): less hacky way to do this
    cache_dtype_str: str | None = None

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            `block_size * 656` when `cache_dtype_str` is "fp8_ds_mla",
            otherwise `block_size * num_kv_heads * head_size * dtype_size`.
        """
        if self.cache_dtype_str == "fp8_ds_mla":
            # See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
            #  for details.
            return self.block_size * 656
        return (
            self.block_size
            * self.num_kv_heads
            * self.head_size
            * get_dtype_size(self.dtype)
        )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of MLAAttentionSpec objects into a single
        MLAAttentionSpec object.

        `block_size`, `num_kv_heads`, `head_size` and `dtype` are taken from
        the first spec; every other field except `cache_dtype_str` is left at
        its default.

        Raises:
            AssertionError: If `specs` is empty, contains a non
                MLAAttentionSpec, or the specs do not all share the same
                `cache_dtype_str`.
        """
        # Without this guard an empty list would trip the "same quantization
        # method" assertion below, which is misleading.
        assert specs, "Cannot merge an empty list of KV cache specs."
        assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be MLAAttentionSpec."
        )
        cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
        assert len(cache_dtype_str_set) == 1, (
            "All attention layers in the same KV cache group must use the same "
            "quantization method."
        )
        return cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            cache_dtype_str=cache_dtype_str_set.pop(),
        )


217
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    """
    An AttentionSpec that additionally records the attention chunk size.
    """

    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `max_model_len` and
                `max_num_batched_tokens`.

        Returns:
            The number of blocks needed for
            `min(attention_chunk_size + max_num_batched_tokens,
            max_model_len)` tokens, multiplied by `page_size_bytes`.
        """
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.attention_chunk_size + max_num_batched_tokens, max_model_len
        )

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

235

236
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    """
    An AttentionSpec that additionally records the sliding window size.
    """

    sliding_window: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `max_model_len`, `max_num_batched_tokens`
                and the decode context parallel size.

        Returns:
            One block more than needed for
            `min(sliding_window - 1 + max_num_batched_tokens, max_model_len)`
            tokens, multiplied by `page_size_bytes`.

        Raises:
            AssertionError: If the decode context parallel size is not 1.
        """
        assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
            "DCP not support sliding window."
        )
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(
            self.sliding_window - 1 + max_num_batched_tokens, max_model_len
        )

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_tokens
        # is 4, we need two blocks [XXCD] [EF] to store the 4-token sliding
        # window [CDEF].
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
260
261


262
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    """
    A KV cache spec whose page is a collection of state tensors, each
    described by a shape and a dtype.
    """

    # Shapes of the state tensors; paired element-wise with `dtypes`.
    shapes: tuple[tuple[int, ...], ...]
    # One dtype per entry of `shapes` (variadic, like `shapes`).
    dtypes: tuple[torch.dtype, ...]
    # When set, `page_size_bytes` returns this value instead of the computed
    # size; it must not be smaller than the computed size.
    page_size_padded: int | None = None
    # Not read by the methods of this class.
    mamba_type: str = "mamba2"
    # Not read by the methods of this class.
    num_speculative_blocks: int = 0

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page in bytes.

        Returns:
            `page_size_padded` when it is set, otherwise the sum over all
            (shape, dtype) pairs of `prod(shape) * dtype_size`.

        Raises:
            AssertionError: If `page_size_padded` is set but smaller than the
                computed size.
        """
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `max_model_len`.

        Returns:
            The number of blocks needed for `max_model_len` tokens,
            multiplied by `page_size_bytes`.
        """
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
Chen Zhang's avatar
Chen Zhang committed
284
285


286
287
288
289
290
291
292
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):
    """
    An AttentionSpec for encoder-only layers, which report no KV cache
    memory usage.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            Always 0; `vllm_config` is not consulted.
        """
        # Encoder-only layers do not need KV cache
        return 0


293
294
295
296
297
298
299
300
301
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `max_num_encoder_input_tokens`.

        Returns:
            The number of blocks needed for the maximum number of encoder
            input tokens, multiplied by `page_size_bytes`.
        """
        # For cross-attention, we need to cache encoder states
        # Get encoder length (e.g., 1500 for Whisper).
        max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes


306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """
    A FullAttentionSpec that additionally records `sink_len`.
    """

    # NOTE(review): presumably the number of attention-sink tokens -- confirm.
    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        All validation and the merging of the shared attention fields are
        delegated to `FullAttentionSpec.merge`. `sink_len` is taken from the
        first spec; it is not compared across specs.

        Raises:
            AssertionError: Propagated from `FullAttentionSpec.merge`.
            ValueError: Propagated from `FullAttentionSpec.merge`.
        """
        # The parent builds `cls(...)` without `sink_len`, so the result is a
        # SinkFullAttentionSpec with the default `sink_len`; fill it in.
        merged_spec = super().merge(specs)
        return replace(merged_spec, sink_len=specs[0].sink_len)


356
357
358
359
360
361
362
363
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    same types means always need the same number of token slots. For example,
    sliding window attentions with different window sizes are not the same type
    and should not be merged into one UniformTypeKVCacheSpecs.
    """

    # Per-layer specs, keyed by a string (layer name, judging by
    # `KVCacheGroupSpec.layer_names` -- TODO confirm).
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page in bytes.

        Returns:
            The sum of `page_size_bytes` over all contained specs.
        """
        return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Forwarded to each contained spec.

        Returns:
            The largest page count required by any contained spec, multiplied
            by the combined `page_size_bytes`.
        """
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
            for spec in self.kv_cache_specs.values()
        )
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.

        Args:
            kv_cache_specs: The per-layer specs to inspect.

        Returns:
            False if the specs use more than one block size. Otherwise True
            only if every spec is an instance of the same supported class as
            the first spec and, for sliding window, chunked local attention
            and mamba specs, also agrees with it on `sliding_window`,
            `attention_chunk_size` and `num_speculative_blocks` respectively.

        Raises:
            NotImplementedError: If the first spec is not one of the
                supported spec types.
        """
        block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = next(iter(kv_cache_specs.values()))
        # isinstance is used throughout, so subclasses of each spec type are
        # accepted by the corresponding branch.
        if isinstance(one_spec, FullAttentionSpec):
            return all(
                isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, CrossAttentionSpec):
            return all(
                isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values()
            )
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec)
                and spec.num_speculative_blocks == one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values()
            )
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}"
            )

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not.

        The `block_size` of the returned object is the block size of the
        first spec in `kv_cache_specs`.
        """
        if cls.is_uniform_type(kv_cache_specs):
            block_size = next(iter(kv_cache_specs.values())).block_size
            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
        else:
            return None


433
434
435
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """

    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor
441
442


443
444
445
446
447
448
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """

    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec


456
457
458
459
460
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """

    num_blocks: int
    """The number of KV cache blocks"""
    kv_cache_tensors: list[KVCacheTensor]
    """How should model runner initialize the KV cache tensors for each layer"""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The kv cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """