"platforms/common/src/kernels/langevin.cc" did not exist on "099706320d5689dfa9758fcad8eb1d6f21bdc43d"
flash_attn.py 8.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.forward_context import get_forward_context
10
from vllm.utils import direct_register_custom_op
11
12
13
14
15
16
17
18
19
20
21
from vllm.vllm_flash_attn import flash_attn_varlen_func


class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the V1 FlashAttention path.

    Ties together the attention implementation class, its metadata class and
    the layout of the paged KV cache that the implementation reads and writes.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the accepted head sizes: every multiple of 32 up to 256."""
        return list(range(32, 256 + 1, 32))

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the class that performs the attention computation."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-step metadata class consumed by the implementation."""
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined key/value cache tensor.

        Args:
            num_blocks: Number of cache blocks.
            block_size: Tokens per block; must be a multiple of 16.
            num_kv_heads: Number of key/value heads.
            head_size: Size of each head.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``; the
            leading axis selects the key cache (0) or the value cache (1).

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)


@dataclass
class FlashAttentionMetadata:
    """Per-step batch metadata read by the FlashAttention custom op.

    Length terminology for a single sequence at iteration N (NOTE(sang)):

        |---------- N-1 iteration --------|
        |---------------- N iteration ---------------------|
        |- tokenA -|......................|-- newTokens ---|
        |---------- context_len ----------|
        |-------------------- seq_len ---------------------|
                                          |-- query_len ---|
    """

    # Number of tokens excluding padding; only this many query/key/value rows
    # are cached and attended over.
    num_actual_tokens: int
    # Passed to flash-attn as `max_seqlen_q`.
    max_query_len: int
    # Passed to flash-attn as `cu_seqlens_q` (cumulative query offsets).
    query_start_loc: torch.Tensor
    # Passed to flash-attn as `max_seqlen_k`.
    max_seq_len: int
    # Passed to flash-attn as `cu_seqlens_k` (cumulative key offsets).
    seq_start_loc: torch.Tensor
    # Passed to flash-attn as `block_table` for paged KV-cache lookup.
    block_table: torch.Tensor
    # Cache slots used by `reshape_and_cache_flash` to store the new K/V rows.
    slot_mapping: torch.Tensor


class FlashAttentionImpl(AttentionImpl):
    """Decoder self-attention that dispatches to the
    ``torch.ops.vllm.unified_v1_flash_attention`` custom op."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and precompute flash-attn arguments.

        Raises:
            ValueError: If block-sparse parameters are given, or if
                ``head_size`` is not one of the backend's supported sizes.
            AssertionError: If ``num_heads`` is not a multiple of
                ``num_kv_heads``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        # "No sliding window" is expressed as (-1, -1). A window of W tokens
        # becomes (W - 1) tokens to the left and 0 to the right.
        self.sliding_window = ((-1, -1) if sliding_window is None else
                               (sliding_window - 1, 0))
        self.kv_cache_dtype = kv_cache_dtype
        # In flash-attn, a logits soft cap of 0 means "no soft cap".
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        allowed_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in allowed_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {allowed_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run decoder self-attention through the FlashAttention custom op.

        Args:
            query: ``[num_tokens, num_heads * head_size]``.
            key: ``[num_tokens, num_kv_heads * head_size]``.
            value: ``[num_tokens, num_kv_heads * head_size]``.
            kv_cache: ``[2, num_blocks, block_size, num_kv_heads, head_size]``.
            attn_metadata: Not read here. The custom op fetches the metadata
                from the forward context instead.
            k_scale: Must be 1.0 (FP8 KV cache is unsupported).
            v_scale: Must be 1.0 (FP8 KV cache is unsupported).
            attn_type: Only ``AttentionType.DECODER`` is implemented.

        Returns:
            Tensor of shape ``[num_tokens, num_heads * head_size]``. It is
            allocated with ``torch.empty_like``, so any rows the op does not
            write stay uninitialized.
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        n_heads = self.num_heads
        n_kv_heads = self.num_kv_heads
        dim = self.head_size

        # Reshape outside the custom op to keep CPU overhead out of the
        # non-CUDA-graph region (NOTE(woosuk)).
        q = query.view(-1, n_heads, dim)
        k = key.view(-1, n_kv_heads, dim)
        v = value.view(-1, n_kv_heads, dim)

        attn_out = torch.empty_like(q)
        torch.ops.vllm.unified_v1_flash_attention(
            attn_out,
            q,
            k,
            v,
            n_heads,
            dim,
            n_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
        )
        return attn_out.view(-1, n_heads * dim)
164
165


Joe Runde's avatar
Joe Runde committed
166
def unified_v1_flash_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> None:
    """Store new K/V into the paged cache, then run FlashAttention in place.

    The batch metadata is read from the current forward context, not from the
    arguments. When no metadata is set (profiling run), the function returns
    immediately without touching ``output`` or ``kv_cache``.

    Args:
        output: Written in place for its first ``num_actual_tokens`` rows.
        query: Query tensor, presumably ``[num_tokens, num_heads, head_size]``
            as prepared by the caller (confirm against the call site).
        key: Key tensor; its first ``num_actual_tokens`` rows are cached.
        value: Value tensor; its first ``num_actual_tokens`` rows are cached.
        num_heads: Not used in this body (part of the op signature).
        head_size: Not used in this body (part of the op signature).
        num_kv_heads: Not used in this body (part of the op signature).
        kv_cache: ``kv_cache[0]`` is the key cache and ``kv_cache[1]`` the
            value cache; both are updated in place.
        kv_cache_dtype: Forwarded to ``reshape_and_cache_flash``.
        k_scale: Forwarded to ``reshape_and_cache_flash``.
        v_scale: Forwarded to ``reshape_and_cache_flash``.
        softmax_scale: Scale applied to QK^T by flash-attn.
        window_size: Sliding-window extents. ``None`` means no window, i.e.
            ``(-1, -1)``.
        alibi_slopes: Optional ALiBi slopes forwarded to flash-attn.
        logits_soft_cap: Soft cap on the logits. ``None`` means disabled,
            i.e. 0.
    """
    attn_metadata = get_forward_context().dynamic_forward_context
    if attn_metadata is None:
        # Profiling run: nothing to cache or attend over.
        return
    assert isinstance(attn_metadata, FlashAttentionMetadata)

    # The op signature allows None here, but flash-attn needs concrete values.
    # Map None to the "disabled" values FlashAttentionImpl also uses.
    if window_size is None:
        window_size = (-1, -1)
    if logits_soft_cap is None:
        logits_soft_cap = 0

    num_actual_tokens = attn_metadata.num_actual_tokens

    # Reshape the input keys and values and store them in the cache.
    # Rows past `num_actual_tokens` are padding and are skipped.
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key[:num_actual_tokens],
        value[:num_actual_tokens],
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    # Compute attention and update output up to `num_actual_tokens`. K/V are
    # read from the paged cache (via `block_table`), so this includes the
    # tokens just stored above.
    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        cu_seqlens_k=attn_metadata.seq_start_loc,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=softmax_scale,
        causal=True,
        alibi_slopes=alibi_slopes,
        window_size=window_size,
        block_table=attn_metadata.block_table,
        softcap=logits_soft_cap,
    )


Joe Runde's avatar
Joe Runde committed
227
def unified_v1_flash_attention_fake(
    output: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> None:
    """No-op stand-in registered as the op's ``fake_impl``.

    The signature mirrors the real op exactly. The real op returns nothing
    and only mutates ``output``/``kv_cache``, so there is no result to build.
    """
    return None
245
246
247


# Register the kernel as torch.ops.vllm.unified_v1_flash_attention. It writes
# into `kv_cache` and `output` in place, so both are declared as mutated. The
# no-op fake implementation is presumably used when tracing/compiling without
# running the real kernel (torch.library fake-impl semantics).
direct_register_custom_op(
    op_name="unified_v1_flash_attention",
    op_func=unified_v1_flash_attention,
    fake_impl=unified_v1_flash_attention_fake,
    mutates_args=["kv_cache", "output"],
)