kv_cache_interface.py 6.96 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import copy
5
from dataclasses import dataclass
Chen Zhang's avatar
Chen Zhang committed
6
from math import prod
7
from typing import Optional
8
9

import torch
10
from typing_extensions import Self
11

12
from vllm.config import VllmConfig
13
14
15
16
17
18
19
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)


@dataclass
class KVCacheSpec:
    """
    Base description of how the KV cache of a single layer is laid out.

    Concrete cache kinds are expected to override ``type_id``,
    ``page_size_bytes`` and ``max_memory_usage_bytes``; the base versions
    only raise ``NotImplementedError``.
    """

    # How many tokens fit into one KV cache block.
    block_size: int

    @property
    def type_id(self) -> str:
        """
        Identifier that tells incompatible KV cache kinds apart.

        Layers whose caches differ must yield different strings, e.g. a
        different number of cached tokens (full attention vs. sliding window
        attention) or a different per-token cache size (layers with a
        different number of heads).

        Returns:
            The type identifier of this KV cache.
        """
        raise NotImplementedError

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page, i.e. of a block holding ``block_size``
        tokens.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound on the memory this KV cache can require, in bytes.

        Args:
            vllm_config: The engine configuration.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Collapse several KVCacheSpec objects into one.

        Every spec must report the same ``type_id``. The result is a deep
        copy of the first spec, so mutating it leaves the inputs untouched.
        """
        first = specs[0]
        assert all(spec.type_id == first.type_id for spec in specs[1:]), (
            "All layers in the same KV cache group must share the same "
            "type_id.")
        return copy.deepcopy(first)


@dataclass
class AttentionSpec(KVCacheSpec):
    """Common KV cache layout fields for attention layers."""

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        # MLA stores a single latent vector per token; otherwise two
        # entries (key and value) are stored.
        num_tensors = 1 if self.use_mla else 2
        elems_per_tensor = (self.block_size * self.num_kv_heads *
                            self.head_size)
        return num_tensors * elems_per_tensor * get_dtype_size(self.dtype)


@dataclass
class FullAttentionSpec(AttentionSpec):
    """KV cache spec for layers whose blocks cover every token."""

    sliding_window: Optional[int] = None
    """
    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, the sliding window
    layers are treated as full attention by the KV cache manager (blocks are
    allocated for all tokens) while still being computed as sliding window
    attention in the model runner.
    In this case, we use FullAttentionSpec and record the sliding window size
    here. Defaults to None when sliding window attention is not used.
    """

    @property
    def type_id(self) -> str:
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Bytes needed to hold ``max_model_len`` tokens, rounded up to whole
        blocks.
        """
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        On top of the base ``type_id`` check, the merged ``sliding_window``
        is the single window size shared by the specs that have one, or
        None if no spec has a window.

        Raises:
            ValueError: If the specs that have a window disagree on its size.
        """
        merged_spec = super().merge(specs)
        window_sizes = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        if not window_sizes:
            merged_spec.sliding_window = None
        elif len(window_sizes) == 1:
            merged_spec.sliding_window = next(iter(window_sizes))
        else:
            raise ValueError(
                "All sliding window layers in the same KV cache group "
                "must have the same window size.")
        return merged_spec


@dataclass
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for layers computed with sliding window attention."""

    sliding_window: int

    def __post_init__(self):
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        Upper bound on this layer's KV cache size in bytes.

        Covers the window plus one scheduling step of new tokens (capped at
        ``max_model_len``), plus one extra block for a window that does not
        start on a block boundary.
        """
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)

        # +1 here because the window may not start at a block boundary.
        # Example: block_size = 4 and num_tokens = 4. In the 6-token sequence
        # XXCDEF the window is the last 4 tokens [CDEF], which straddles two
        # blocks, [XXCD] [EF__], even though cdiv(4, 4) is only 1.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes


@dataclass
class MambaSpec(KVCacheSpec):
    """
    KV cache spec for Mamba layers.

    ``shapes`` gives the shape of each tensor that makes up one page; all of
    them use ``dtype``.
    """

    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype

    def __post_init__(self):
        # Total element count over all tensors of a page, computed once at
        # construction time.
        self.num_elements = sum(map(prod, self.shapes))

    @property
    def type_id(self) -> str:
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        bytes_per_element = get_dtype_size(self.dtype)
        return self.num_elements * bytes_per_element

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Each request currently gets exactly one block, so the peak usage is
        # a single page. This must change once prefix caching is supported.
        return self.page_size_bytes


@dataclass
class KVCacheTensor:
    """
    Describes one KV cache tensor the workers should allocate, possibly
    shared by several layers.
    """

    # Size of the tensor, in bytes.
    size: int
    # Names of every layer that uses this one tensor.
    shared_by: list[str]


@dataclass
class KVCacheGroupSpec:
    """
    A set of model layers that use one common KV cache block table.
    The KV cache manager handles the whole set as if it were a single layer.
    """
    # Names of the model layers belonging to this group.
    layer_names: list[str]
    # The KV cache spec that the KV cache manager uses for the whole group.
    kv_cache_spec: KVCacheSpec


@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    num_blocks: int
    """The number of KV cache blocks."""
    kv_cache_tensors: list[KVCacheTensor]
    """How the model runner should initialize the KV cache tensors for each
    layer."""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The KV cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """