"""Attention layer."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    The computation is delegated to ``self.impl``, built from the selected
    attention backend's implementation class. Each layer registers itself in
    the compilation config's ``static_forward_context`` under ``prefix``; the
    same name is kept in ``self.layer_name`` and passed to the
    ``unified_attention*`` functions at forward time.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Select an attention backend and build its implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scaling factor forwarded to the backend implementation
                (reported as ``scale`` by ``extra_repr``).
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Supplies the model-level sliding window, KV-cache
                dtype, block size and attention-free flag. When ``None``,
                no model-level sliding window is used and the dtype, block
                size and flag default to ``"auto"``, ``16`` and ``False``.
            quant_config: If given and it returns a quant method for this
                layer, that method must be a ``BaseKVCacheMethod``; it is
                stored as ``self.quant_method`` and creates its weights on
                this layer.
            blocksparse_params: Forwarded to the backend implementation;
                whether it is ``None`` is also passed to ``get_attn_backend``.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: When not ``None``, overrides the
                model-level sliding window from ``cache_config``.
            prefix: Unique layer name, stored as ``self.layer_name``.
            attn_type: Attention type forwarded to the backend and stored as
                ``self.attn_type``.

        Raises:
            ValueError: If a quant method is configured while the KV-cache
                dtype is ``"fp8_e5m2"``, or if ``prefix`` is already
                registered in ``static_forward_context``.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                        block_size, is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.backend = backend_name_to_enum(attn_backend.get_name())

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        # When True, forward() pre-allocates the output tensor and uses the
        # unified_attention_with_output path, which fills it in place.
        self.use_output = attn_backend.accept_output_buffer

        # Read the current config once: it is needed both to register this
        # layer and to size the placeholder KV-cache list below.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        # Placeholder KV-cache tensors used during init, to be replaced by
        # bind_kv_cache. There is one entry per pipeline-parallel virtual
        # engine; both the direct-call and custom-op paths read it through
        # unified_attention*, indexed by forward_context.virtual_engine.
        self.kv_cache = [
            torch.tensor([]) for _ in range(
                vllm_config.parallel_config.pipeline_parallel_size)
        ]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for ``query`` against ``key`` and ``value``.

        ``kv_cache`` and ``attn_metadata`` are not read by this method; the
        ``unified_attention*`` functions fetch this layer's KV cache and the
        attention metadata from the current forward context instead.

        Dispatch is either a direct Python call (``self.use_direct_call``)
        or the registered ``torch.ops.vllm`` custom op. The direct call is
        used on non-CUDA-alike, non-CPU platforms.
        """
        if self.use_output:
            output = torch.empty_like(query)
            # Last dim of the flat query; used to flatten the result back.
            hidden_size = query.size(-1)
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                unified_attention_with_output(query, key, value, output,
                                              self.layer_name)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def extra_repr(self) -> str:
        """Summarize the backend implementation's settings for ``repr()``."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s


class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    Supports multi-query / grouped-query attention: when ``num_kv_heads`` is
    smaller than ``num_heads``, every key/value head is repeated so that each
    group of ``num_heads // num_kv_heads`` consecutive query heads shares one
    key/value head.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Pick an xformers or torch SDPA kernel for cache-free attention.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``. Must be a positive divisor of ``num_heads``.

        Raises:
            ValueError: If ``num_kv_heads`` is not a positive divisor of
                ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if self.num_kv_heads <= 0 or num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_kv_heads ({self.num_kv_heads}) must be a positive "
                f"divisor of num_heads ({num_heads})")
        # Number of query heads sharing each key/value head (1 for plain MHA).
        self.num_queries_per_kv = num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        # forward() has no FlashAttention path; use xformers instead.
        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            backend = _Backend.XFORMERS

        # Only xformers and torch SDPA are implemented in forward(); every
        # other backend falls back to torch SDPA.
        self.attn_backend = backend if backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    @staticmethod
    def _repeat_kv(x: torch.Tensor, num_repeat: int) -> torch.Tensor:
        """Repeat each head of ``x`` ``num_repeat`` times along dim 2.

        ``x`` is viewed by ``forward`` as
        ``(batch, seq_len, num_kv_heads, head_size)``. The result has
        ``num_kv_heads * num_repeat`` heads, and output head ``h`` is a copy
        of input head ``h // num_repeat``. ``x`` itself is returned when
        ``num_repeat`` is 1.
        """
        if num_repeat == 1:
            return x
        return torch.repeat_interleave(x, num_repeat, dim=2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        ``query`` carries ``num_heads * head_size`` features per token and
        ``key``/``value`` carry ``num_kv_heads * head_size``. ``key`` and
        ``value`` may have a different sequence length than ``query``.
        Returns a ``(batch_size, q_len, num_heads * head_size)`` tensor.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA2
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        # Both kernels below are called with matching q/k/v head counts, so
        # expand the key/value heads for MQA/GQA.
        key = self._repeat_kv(key, self.num_queries_per_kv)
        value = self._repeat_kv(value, self.num_queries_per_kv)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq_len, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        return out.reshape(bsz, q_len, -1)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered as ``layer_name`` and return output.

    The layer, the attention metadata and the virtual engine index all come
    from the current forward context. The layer's KV cache for that virtual
    engine, together with its k/v scales, is passed to the backend's
    ``impl.forward``.

    Args:
        query, key, value: Attention inputs as passed by ``Attention.forward``.
        layer_name: Name under which the ``Attention`` layer was registered.

    Returns:
        Whatever the backend implementation's ``forward`` returns.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    layer = forward_context.attn_layers[layer_name]
    kv_cache = layer.kv_cache[forward_context.virtual_engine]
    return layer.impl.forward(query, key, value, kv_cache, attn_metadata,
                              layer._k_scale, layer._v_scale)


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered alongside ``unified_attention``.

    It does not compute attention. It only allocates an uninitialised,
    contiguous tensor with the same shape, dtype and device as ``query``.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register ``unified_attention`` as the custom op that ``Attention.forward``
# invokes as ``torch.ops.vllm.unified_attention`` when it does not call the
# function directly. The op mutates none of its arguments, and
# ``unified_attention_fake`` (which only allocates an output shaped like
# ``query``) is supplied as its fake implementation.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the attention layer ``layer_name``, writing into ``output``.

    This is the same lookup as ``unified_attention``: the layer, attention
    metadata and per-virtual-engine KV cache come from the current forward
    context. The difference is that the pre-allocated ``output`` tensor is
    handed to the backend's ``impl.forward`` to be filled in place, and
    nothing is returned.

    Args:
        query, key, value: Attention inputs, already reshaped to per-head
            layout by ``Attention.forward``.
        output: Pre-allocated tensor that receives the attention result.
        layer_name: Name under which the ``Attention`` layer was registered.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    layer = forward_context.attn_layers[layer_name]
    kv_cache = layer.kv_cache[forward_context.virtual_engine]
    layer.impl.forward(query,
                       key,
                       value,
                       kv_cache,
                       attn_metadata,
                       layer._k_scale,
                       layer._v_scale,
                       output=output)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation registered alongside the with-output op.

    The real op writes its result into ``output`` and returns nothing, so
    there is nothing to allocate or compute here.
    """
    return None


# Register ``unified_attention_with_output`` as the custom op that
# ``Attention.forward`` invokes as
# ``torch.ops.vllm.unified_attention_with_output``. ``output`` is declared
# as the only mutated argument because the backend fills it in place; the
# fake implementation therefore has nothing to return.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)