from typing import Tuple
from dataclasses import dataclass, field

from loguru import logger
import torch

from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights


@dataclass
class KVScales:
    """
    Key-value scales for FP8 KV cache.

    This data class stores key and value scales both as a GPU tensor and
    as a CPU float. This inconvenience is necessary because some functions
    (e.g. scaling kernels) take scales as a GPU tensor, whereas others
    (e.g. flashinfer) take scales as a CPU scalar.
    """

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
            raise ValueError("Key and value scales must be scalar tensors.")

        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()
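
# Usage sketch (illustrative, not part of the upstream module): KVScales wraps
# a pair of scalar tensors and mirrors them as Python floats. The value below
# is hypothetical.
#
#     scale = torch.tensor(1.0, dtype=torch.float32)
#     scales = KVScales(key_scale=scale, value_scale=scale)
#     assert scales.key_scale_cpu == 1.0 and scales.value_scale_cpu == 1.0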


class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""

        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
            ATTENTION != "flashinfer" or SYSTEM != "cuda"
        ):
            raise ValueError(
                "FP8 KV cache is currently only supported for flashinfer on CUDA"
            )

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        else:
            self.kv_cache = (
                torch.zeros(
                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros(
                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
                    dtype=dtype,
                    device=device,
                ),
            )

    def can_scale(self, kv_scales: KVScales) -> bool:
        """Check if the cache can be scaled by the given scales."""
        if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
            return False
        elif (
            self.dtype == torch.float8_e4m3fn
            and ATTENTION == "flashinfer"
            and SYSTEM == "cuda"
        ):
            log_once(
                logger.info,
                "Using FP8 KV cache scales",
            )
            return True
        else:
            # We have scales, but not the correct FP8 cache type, so warn once.
            log_once(
                logger.info,
                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
            )
            return False

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """Get the key cache."""

        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""

        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Store the key and value at the given slots."""

        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        if self.can_scale(kv_scales):
            if kv_scales.key_scale_cpu != 1.0:
                key = fp8_quantize(
                    key.float(),
                    scale=kv_scales.key_scale,
                    qdtype=self.dtype,
                    scalar=True,
                )[0]
            if kv_scales.value_scale_cpu != 1.0:
                value = fp8_quantize(
                    value.float(),
                    scale=kv_scales.value_scale,
                    qdtype=self.dtype,
                    scalar=True,
                )[0]

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
                # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
                # put as raw data instead.
                key_cache = key_cache.view(torch.uint8)
                value_cache = value_cache.view(torch.uint8)
                key = key.view(torch.uint8)
                value = value.view(torch.uint8)
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        else:
            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
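
# Usage sketch (illustrative only): allocate a small cache and store one
# token's key/value projections. The shapes, dtype, and device below are
# hypothetical; `slots` holds flattened slot indices (block * BLOCK_SIZE +
# offset), and unit scales leave keys/values unquantized.
#
#     cache = KVCache(
#         num_blocks=128,
#         num_heads=8,
#         head_size=64,
#         dtype=torch.float16,
#         device=torch.device("cuda"),
#     )
#     scales = KVScales(
#         key_scale=torch.tensor(1.0, device="cuda"),
#         value_scale=torch.tensor(1.0, device="cuda"),
#     )
#     cache.store(
#         key=torch.randn(1, 8, 64, dtype=torch.float16, device="cuda"),
#         value=torch.randn(1, 8, 64, dtype=torch.float16, device="cuda"),
#         slots=torch.tensor([0], device="cuda"),
#         kv_scales=scales,
#     )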


def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if SYSTEM == "cuda":
        try:
            import attention_kernels
        except Exception as e:
            raise ImportError(
                f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
            )
        attention_kernels.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif SYSTEM == "rocm":
        try:
            import vllm._custom_ops as ops
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
        )
    elif SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex

        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots
        )
    else:
        raise NotImplementedError(
            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
        )


def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """Load KV cache scales."""

    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    value_scale = key_scale
    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
        f"{prefix}.v_scale"
    ):
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Fall back to the older, more coarse-grained scale when available.
        key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)
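
# Usage sketch (illustrative only): `weights` is assumed to be a Weights
# instance for a checkpoint, and the prefix is a hypothetical attention module
# name. When the checkpoint carries no k_scale/v_scale (or kv_scale) tensors,
# both scales default to 1.0 and KVCache.can_scale() reports scaling disabled.
#
#     kv_scales = get_kv_scales(weights, "model.layers.0.self_attn")
#     if kv_cache.can_scale(kv_scales):
#         ...  # KVCache.store() will quantize keys/values with these scales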