layer.py 9.67 KB
Newer Older
1
"""Attention layer."""
2
from typing import Any, Dict, List, Optional
3
4
5
6

import torch
import torch.nn as nn

7
from vllm.attention import AttentionMetadata, AttentionType
8
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
9
from vllm.config import CacheConfig, get_current_vllm_config
10
from vllm.forward_context import ForwardContext, get_forward_context
11
12
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
13
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
14
from vllm.platforms import _Backend, current_platform
15
from vllm.utils import direct_register_custom_op
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    Each instance registers itself in the current compilation config's
    ``static_forward_context`` under its ``prefix``, so every layer must be
    constructed with a distinct ``prefix``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        """Select an attention backend, build its impl and register the layer.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scaling factor forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend impl.
            cache_config: Source of the KV-cache dtype, block size,
                model-level sliding window and attention-free flag. When
                ``None``, dtype ``"auto"``, block size 16, no sliding window
                and ``is_attention_free=False`` are used.
            quant_config: If given, asked for a quant method for this layer.
                A returned method must be a ``BaseKVCacheMethod`` and is used
                to create the k/v scale weights.
            blocksparse_params: Forwarded to the backend impl. Whether it is
                set also affects backend selection.
            logits_soft_cap: Forwarded to the backend impl.
            per_layer_sliding_window: Overrides the model-level sliding window
                when not ``None``.
            prefix: Layer name. It is used as this layer's key in the
                forward context and must be unique.

        Raises:
            TypeError: If ``quant_config`` returns a quant method that is not
                a ``BaseKVCacheMethod``.
            ValueError: If a quant method is present while the KV-cache dtype
                is ``"fp8_e5m2"``, or if ``prefix`` is already registered.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            # Explicit check rather than `assert`: asserts are stripped when
            # Python runs with -O, which would let a wrong method through.
            if not isinstance(quant_method, BaseKVCacheMethod):
                raise TypeError(
                    "Expected the attention quant method to be a "
                    f"BaseKVCacheMethod, got {type(quant_method).__name__}.")
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                        block_size, is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.backend = backend_name_to_enum(attn_backend.get_name())

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        # For some attention backends, we allocate an output tensor before
        # calling the custom op. When piecewise cudagraph is enabled, this
        # makes sure the output tensor is allocated inside the cudagraph.
        self.use_output = self.backend in (_Backend.FLASH_ATTN,
                                           _Backend.FLASH_ATTN_VLLM_V1)

        # Register this layer so the custom ops can find it by name at
        # runtime; duplicate names would make the lookup ambiguous.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention for this layer.

        One of three paths is taken, based on flags set in ``__init__``:

        * ``use_direct_call``: call ``self.impl.forward`` directly with
          ``attn_metadata`` and the k/v scales.
        * ``use_output``: pre-allocate the output, view q/k/v and the output
          as ``(-1, heads, head_size)``, and call the
          ``vllm.unified_attention_with_output`` custom op, which fills the
          output in place. ``key``/``value`` may be ``None`` on this path.
        * otherwise: call the ``vllm.unified_attention`` custom op.

        On both custom-op paths ``attn_metadata`` is not passed. The op reads
        metadata from the forward context instead.

        Returns:
            The attention output. On the ``use_output`` path it is viewed back
            to ``(-1, query.size(-1))``.
        """
        if self.use_direct_call:
            return self.impl.forward(query,
                                     key,
                                     value,
                                     kv_cache,
                                     attn_metadata,
                                     self._k_scale,
                                     self._v_scale,
                                     attn_type=attn_type)
        elif self.use_output:
            output = torch.empty_like(query)
            hidden_size = query.size(-1)
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, kv_cache, attn_type,
                self.layer_name)
            return output.view(-1, hidden_size)
        else:
            return torch.ops.vllm.unified_attention(query, key, value,
                                                    kv_cache, attn_type,
                                                    self.layer_name)

    def extra_repr(self) -> str:
        """Summarize the backend impl's configuration for the module repr."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Body of the ``vllm.unified_attention`` custom op.

    Resolves the ``Attention`` layer registered as ``layer_name`` in the
    current forward context. It then runs that layer's backend ``impl.forward``
    with the context's ``dynamic_forward_context`` as the attention metadata
    and the layer's k/v scales. The backend's result is returned unchanged.

    NOTE(review): the op schema is presumably derived from this signature by
    ``direct_register_custom_op``; keep parameter names/annotations stable.
    """
    context: ForwardContext = get_forward_context()
    metadata = context.dynamic_forward_context
    layer = context.static_forward_context[layer_name]
    backend_forward = layer.impl.forward
    return backend_forward(query, key, value, kv_cache, metadata,
                           layer._k_scale, layer._v_scale,
                           attn_type=attn_type)


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered for ``vllm.unified_attention``.

    Only the result's metadata matters here. It returns a fresh, contiguous,
    uninitialized tensor with the same shape, dtype and device as ``query``.
    All other arguments are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose `unified_attention` as the `torch.ops.vllm.unified_attention` custom
# op called from `Attention.forward`. `kv_cache` is declared as mutated, and
# `unified_attention_fake` is supplied as the fake implementation.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    mutates_args=["kv_cache"],
    dispatch_key=current_platform.dispatch_key,
)
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> None:
    """Body of the ``vllm.unified_attention_with_output`` custom op.

    Like ``unified_attention``, but passes the caller-allocated ``output``
    tensor to the backend's ``impl.forward`` instead of using a returned
    tensor. The backend's return value is discarded.

    NOTE(review): the op schema is presumably derived from this signature by
    ``direct_register_custom_op``; keep parameter names/annotations stable.
    """
    context: ForwardContext = get_forward_context()
    metadata = context.dynamic_forward_context
    layer = context.static_forward_context[layer_name]
    backend_forward = layer.impl.forward
    backend_forward(query, key, value, kv_cache, metadata,
                    layer._k_scale, layer._v_scale,
                    attn_type=attn_type, output=output)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> None:
    """Fake implementation registered for ``vllm.unified_attention_with_output``.

    The real op writes its result into ``output`` and returns nothing, so
    there is no tensor to fabricate. This is a no-op that returns ``None``.
    """
    return None


# Expose `unified_attention_with_output` as the
# `torch.ops.vllm.unified_attention_with_output` custom op used by the
# `use_output` path of `Attention.forward`. Both `kv_cache` and `output` are
# declared as mutated, since the op returns nothing and writes its result
# into `output`.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["kv_cache", "output"],
    dispatch_key=current_platform.dispatch_key,
)