kv_cache_interface.py 7.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import copy
4
from dataclasses import dataclass
5
from typing import Optional
6
7

import torch
8
from typing_extensions import Self
9

10
from vllm.config import VllmConfig
11
12
13
14
15
16
17
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)


@dataclass
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.

    The base implementations of `type_id`, `page_size_bytes` and
    `max_memory_usage_bytes` raise NotImplementedError; subclasses provide
    the concrete values.
    """

    # number of tokens in a block
    block_size: int

    @property
    def type_id(self) -> str:
        """
        The type identifier of this KV cache.
        Return different strings for layers with different KV cache type (e.g.,
        different number of tokens like full attention vs sliding window
        attention, different KV cache size per token like layers with different
        number of heads)

        Returns:
            The type identifier of this KV cache.
        """
        raise NotImplementedError

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: The vLLM configuration the bound is derived from.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

        Args:
            specs: A non-empty list of specs that all share one `type_id`.

        Returns:
            A deep copy of the first spec, so later mutation of the merged
            spec does not affect the inputs.

        Raises:
            ValueError: If `specs` is empty.
            AssertionError: If the specs do not all share the same `type_id`.
        """
        # Fail with a clear message instead of an IndexError on specs[0].
        if not specs:
            raise ValueError(
                "Cannot merge an empty list of KVCacheSpec objects.")
        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
            "All layers in the same KV cache group must share the same "
            "type_id.")
        return copy.deepcopy(specs[0])

69
70

@dataclass
class AttentionSpec(KVCacheSpec):
    """
    KV cache spec for attention layers.

    Adds the fields that, together with `block_size`, determine the size of
    one page (see `page_size_bytes`).
    """
    # number of KV heads stored per token
    num_kv_heads: int
    # size of each head
    head_size: int
    # dtype of the cache entries; its per-element size is obtained via
    # get_dtype_size()
    dtype: torch.dtype
    # whether the layer uses MLA
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            coef * block_size * num_kv_heads * head_size * dtype size, where
            coef is 1 for MLA and 2 otherwise.
        """
        # For MLA we only store a single latent vector
        coef = 1 if self.use_mla else 2
        return (coef * self.block_size * self.num_kv_heads * self.head_size *
                get_dtype_size(self.dtype))

84
85
86

@dataclass
class FullAttentionSpec(AttentionSpec):
    """
    KV cache spec for layers managed as full attention by the KV cache
    manager.
    """
    sliding_window: Optional[int] = None
    """
    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention layers are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    Default to None for not using sliding window attention.
    """

    @property
    def type_id(self) -> str:
        """
        The type identifier of this KV cache.

        Built from `block_size` and `page_size_bytes` only; `sliding_window`
        is not part of the identifier.
        """
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `model_config.max_model_len`.

        Returns:
            cdiv(max_model_len, block_size) pages of `page_size_bytes` each.
        """
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        Args:
            specs: The specs to merge; the base-class merge is applied first.

        Returns:
            The merged spec, whose `sliding_window` is the single window size
            found among `specs`, or None if no spec sets one.

        Raises:
            ValueError: If the specs carry more than one distinct window size.
        """
        merged_spec = super().merge(specs)
        # Distinct window sizes among the layers that actually use one.
        window_sizes = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        if len(window_sizes) > 1:
            raise ValueError(
                "All sliding window layers in the same KV cache group "
                "must have the same window size.")
        # At most one size is left here: use it, or None when there is none.
        merged_spec.sliding_window = (window_sizes.pop()
                                      if window_sizes else None)
        return merged_spec

125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

@dataclass
class SlidingWindowSpec(AttentionSpec):
    """
    KV cache spec for sliding window attention layers.
    """
    # the sliding window size, counted in tokens
    sliding_window: int

    def __post_init__(self):
        """Reject specs that combine MLA with a sliding window."""
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        """
        The type identifier of this KV cache, built from the window size,
        the block size and the page size.
        """
        return (f"sliding_window_{self.sliding_window}_{self.block_size}_"
                f"{self.page_size_bytes}")

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Args:
            vllm_config: Provides `model_config.max_model_len` and
                `scheduler_config.max_num_batched_tokens`.

        Returns:
            The KV cache size in bytes
        """
        length_cap = vllm_config.model_config.max_model_len
        batch_budget = vllm_config.scheduler_config.max_num_batched_tokens

        # With chunked prefill the cache holds the last
        # `self.sliding_window-1` computed tokens together with the newly
        # scheduled tokens, and never more than `max_model_len` tokens.
        windowed_tokens = self.sliding_window - 1 + batch_budget
        num_tokens = min(windowed_tokens, length_cap)

        # One extra page, since the window need not begin at a block
        # boundary. E.g. with block size 4 and num_tokens 4, the 4-token
        # window [CDEF] occupies two blocks: [XXCD] [EF].
        num_pages = cdiv(num_tokens, self.block_size) + 1
        return num_pages * self.page_size_bytes
154
155
156
157
158
159
160
161
162
163
164
165


@dataclass
class KVCacheTensor:
    """
    Tells the workers how to initialize the KV cache of one layer.

    At present it only carries the size of that layer's KV cache. It is
    planned to be extended so that multiple layers can share the same memory
    pool.
    """
    # The size of KV cache Tensor in bytes
    size: int


166
167
168
169
170
171
172
173
174
175
176
177
@dataclass
class KVCacheGroupSpec:
    """
    A group of model layers sharing one KV cache block table.

    The KV cache manager treats all layers of a group as a single layer.
    """
    # Names of the model layers that belong to this group
    layer_names: list[str]
    # KV cache spec describing this group's single manager-level layer
    kv_cache_spec: KVCacheSpec


178
179
180
181
182
183
184
185
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    # The number of KV cache blocks
    num_blocks: int
    # layer_name -> how to initialize KV cache for that layer
    tensors: dict[str, KVCacheTensor]
    # The kv cache groups of the model.
    # The layers in the models are repeated with some patterns, e.g., a model
    # with 10 full attention layers and 20 sliding window attention layers can
    # be regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
    # The KVCacheManager allocates different block tables for each of the 3
    # layers in the pattern, and repeats each of them 10 times to generate the
    # block_table for the 30 layers in the model.
    # Therefore, we can group the layers in the model into 3 groups, each of
    # which contains 10 layers in the model.
    # The KVCacheManager allocates the block_table for each group based on its
    # kv_cache spec, and the model runner applies the block table to each
    # layer in the group.
    # For example:
    # 1. A model only uses full attention. The pattern is
    # (num_hidden_layers * full), so there is only one group and the block
    # table is shared by all layers.
    # 2. (WIP) A model with 10 full attention layers and 20 sliding window
    # attention layers. There are 3 layers in the pattern (1 * full, 2 * sw),
    # so there are 3 groups, each of which represents 10 layers in the model.
    kv_cache_groups: list[KVCacheGroupSpec]