# layer.py
1
"""Attention layer."""
2
from typing import Any, Dict, List, Optional
3
4
5
6

import torch
import torch.nn as nn

7
import vllm.envs as envs
8
from vllm.attention import AttentionMetadata, AttentionType
9
from vllm.attention.selector import get_attn_backend
10
from vllm.config import CacheConfig
11
from vllm.forward_context import ForwardContext, get_forward_context
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
15
16
17
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        """Select an attention backend and build its implementation.

        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            scale: Attention scale, forwarded unchanged to the backend
                implementation.
            num_kv_heads: Number of key/value heads. Defaults to
                ``num_heads`` when ``None``.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Source of the KV-cache dtype, block size, sliding
                window and attention-free flag. When ``None`` the defaults
                ``"auto"``, ``16``, no sliding window and ``False`` are used.
            quant_config: Optional quantization config. If it yields a quant
                method for this layer, that method must be a
                ``BaseKVCacheMethod`` and is used to create the k/v scale
                weights on this module.
            blocksparse_params: Optional block-sparse parameters. They are
                forwarded to the backend, and whether they are set is passed
                to backend selection.
            logits_soft_cap: Optional soft cap forwarded to the backend.
            prefix: Layer name. It is registered in the compilation config's
                static forward context and must be unique.

        Raises:
            ValueError: If a quantized checkpoint is combined with an
                ``fp8_e5m2`` KV cache, or if ``prefix`` is already registered.
        """
        super().__init__()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                        block_size, is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, and whenever VLLM_USE_V1 is
        # enabled, we call the backend directly and let torch.compile
        # handle it.
        self.use_direct_call = envs.VLLM_USE_V1 or not (
            current_platform.is_cuda_alike() or current_platform.is_cpu())
        # Register this layer by name so the custom-op path can find it
        # again from the forward context.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention for this layer and return the backend's output.

        When ``use_direct_call`` is set, the backend implementation is called
        directly with ``attn_metadata`` and this layer's k/v scales.
        Otherwise the call goes through ``torch.ops.vllm.unified_attention``.
        That op looks this layer up by ``layer_name`` and reads the metadata
        from the current forward context, so ``attn_metadata`` is not used on
        that path.
        """
        if self.use_direct_call:
            return self.impl.forward(query,
                                     key,
                                     value,
                                     kv_cache,
                                     attn_metadata,
                                     self._k_scale,
                                     self._v_scale,
                                     attn_type=attn_type)
        else:
            return torch.ops.vllm.unified_attention(query, key, value,
                                                    kv_cache, attn_type,
                                                    self.layer_name)

    def extra_repr(self) -> str:
        """Summarize the backend implementation's configuration for repr()."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Real implementation behind the ``unified_attention`` custom op.

    Looks up the layer registered under ``layer_name`` in the current forward
    context's static map. Takes the attention metadata from the context's
    dynamic part and forwards everything to the layer's backend
    implementation, together with the layer's k/v scales.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.dynamic_forward_context
    layer = ctx.static_forward_context[layer_name]
    scales = (layer._k_scale, layer._v_scale)
    return layer.impl.forward(query, key, value, kv_cache, metadata,
                              *scales, attn_type=attn_type)


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Fake (tracing) implementation of the ``unified_attention`` op.

    Performs no computation. It only returns an uninitialized, contiguous
    tensor with the same shape, dtype and device as ``query``. All other
    arguments are accepted to match the real op's signature and are ignored.
    """
    return torch.empty_like(query, memory_format=torch.contiguous_format)


# Register ``unified_attention`` as a custom op at import time. The op is
# declared to mutate ``kv_cache``, and ``unified_attention_fake`` supplies
# the shape-only implementation used while tracing.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    mutates_args=["kv_cache"],
    dispatch_key=current_platform.dispatch_key,
)