# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
5
from typing import ClassVar
6
7
8

import torch

9
from vllm import envs
10
11
12
13
14
15
16
17
18
from vllm.attention.backends.abstract import (
    AttentionLayer,
    AttentionType,
    is_quantized_kv_cache,
)
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_mla,
    get_flash_attn_version,
)
19
20
from vllm.config import VllmConfig
from vllm.logger import init_logger
21
from vllm.model_executor.layers.batch_invariant import (
22
    vllm_is_batch_invariant,
23
)
24
25
26
27
28
29
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
30
    QueryLenSupport,
31
)
32
from vllm.v1.attention.backends.utils import AttentionCGSupport
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)


class FlashAttnMLABackend(MLACommonBackend):
    """Backend descriptor tying together the FlashAttention MLA classes.

    Every accessor is a static method that returns a constant: the backend's
    registered name or one of the classes defined in this module.
    """

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashAttnMLAImpl

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        """Return the class that builds this backend's attention metadata."""
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        """Return the attention metadata class of this backend."""
        return FlashAttnMLAMetadata

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASH_ATTN_MLA"


@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode-phase metadata consumed by ``FlashAttnMLAImpl._forward_decode``.

    Extends ``MLACommonDecodeMetadata`` with the query layout and the
    scheduling hints that are forwarded to ``flash_attn_varlen_func``.
    """

    # Forwarded to the kernel as ``cu_seqlens_q``.
    query_start_loc: torch.Tensor
    # Largest per-request query length in the decode batch; forwarded
    # (clamped to at least 1) as ``max_seqlen_q``.
    max_query_len: int
    # Forwarded to the kernel as ``max_seqlen_k``.
    max_seq_len: int
    # Ahead-of-time schedule produced by ``get_scheduler_metadata``, or None
    # when no schedule was precomputed.
    scheduler_metadata: torch.Tensor | None = None
    # Forwarded to the kernel as ``num_splits``; the builder uses 0 to mean
    # "no upper bound on the number of splits".
    max_num_splits: int = 0


@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    """``MLACommonMetadata`` specialised to ``FlashAttnMLADecodeMetadata``.

    Adds no fields or behaviour of its own.
    """


71
72
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    """Builds ``FlashAttnMLAMetadata`` for the FlashAttention MLA backend.

    On top of the common MLA builder this class decides how many KV splits
    the decode kernel may use and, when FlashAttention 3 is in use,
    precomputes the kernel schedule ahead of time. Under full CUDA graphs
    the schedule is copied into a persistent buffer allocated once in
    ``__init__``.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
    reorder_batch_threshold: int = 512  # process small prefills with decode pathway

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Initialise the builder and, if needed, the CUDA graph buffers.

        Args:
            kv_cache_spec: Forwarded to the base builder.
            layer_names: Forwarded to the base builder.
            vllm_config: Forwarded to the base builder; also supplies
                ``scheduler_config.max_num_seqs`` for sizing the persistent
                scheduler-metadata buffer.
            device: Forwarded to the base builder.

        Raises:
            ValueError: If full CUDA graphs and ahead-of-time scheduling are
                both active and the maximum capture size exceeds 992.
        """
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
        )
        self.max_num_splits = 0  # No upper bound on the number of splits.
        # Ahead-of-time scheduling is only used with FlashAttention 3.
        self.fa_aot_schedule = get_flash_attn_version() == 3

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for full cuda graph."
                )

            # Persistent buffer that _build_decode copies each schedule into.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Batch-invariant mode overrides whatever was chosen above.
        if vllm_is_batch_invariant():
            self.max_num_splits = 1

    def _schedule_decode(
        self,
        num_reqs: int,
        cu_query_lens: torch.Tensor,
        max_query_len: int,
        seqlens: torch.Tensor,
        max_seq_len: int,
        causal: bool,
        max_num_splits: int,
    ) -> torch.Tensor | None:
        """Precompute the decode kernel schedule when AOT scheduling is on.

        Args:
            num_reqs: Passed to ``get_scheduler_metadata`` as ``batch_size``.
            cu_query_lens: Passed as ``cu_seqlens_q``.
            max_query_len: Passed as ``max_seqlen_q``.
            seqlens: Passed as ``cache_seqlens``.
            max_seq_len: Passed as ``max_seqlen_k``.
            causal: Passed as ``causal``.
            max_num_splits: Passed as ``num_splits``.

        Returns:
            The result of ``get_scheduler_metadata`` when
            ``self.fa_aot_schedule`` is true, otherwise ``None``.
        """
        if not self.fa_aot_schedule:
            return None

        return get_scheduler_metadata(
            batch_size=num_reqs,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_seq_len,
            num_heads_q=self.num_heads * self.dcp_world_size,
            num_heads_kv=1,
            headdim=self.mla_dims.qk_rope_head_dim,
            cache_seqlens=seqlens,
            qkv_dtype=self.kv_cache_spec.dtype,
            headdim_v=self.mla_dims.kv_lora_rank,
            page_size=self.page_size,
            cu_seqlens_q=cu_query_lens,
            causal=causal,
            num_splits=max_num_splits,
        )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashAttnMLADecodeMetadata:
        """Assemble the decode metadata for one batch.

        Args:
            block_table_tensor: Stored as ``block_table`` in the metadata.
            seq_lens_cpu: Only its element count is used, as the number of
                requests handed to the scheduler.
            seq_lens_device: Stored as ``seq_lens``; its maximum becomes
                ``max_seq_len``.
            query_start_loc_cpu: Used to derive the per-request query
                lengths and from them ``max_query_len``.
            query_start_loc_device: Stored as ``query_start_loc`` and passed
                to the scheduler.
            num_decode_tokens: Compared against the maximum CUDA graph
                capture size to decide whether the split cap applies.
            dcp_tot_seq_lens_device: Stored as ``dcp_tot_seq_lens``.

        Returns:
            The populated ``FlashAttnMLADecodeMetadata``.

        Raises:
            AssertionError: If, under full CUDA graphs, the computed schedule
                is larger than the persistent buffer.
        """
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_device.max().item()

        # For Flash Attention MLA + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        # Resolve the batch-invariant override before scheduling so that the
        # ahead-of-time schedule is computed for the same split count that
        # the kernel later receives through the metadata.
        if vllm_is_batch_invariant():
            max_num_splits = 1

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
            max_num_splits=max_num_splits,
        )

        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            buffer_size = self.scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert n <= buffer_size, (
                f"Scheduler metadata size {n} exceeds buffer size {buffer_size}"
            )
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            # Hand out a view of the persistent buffer, not the fresh tensor.
            scheduler_metadata = self.scheduler_metadata[:n]

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )


class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
    """MLA attention implementation whose decode path calls FlashAttention.

    Construction rejects configurations this backend cannot serve;
    ``_forward_decode`` runs ``flash_attn_varlen_func`` over the paged
    KV cache.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialise the base implementation and validate the configuration.

        All arguments are forwarded unchanged to ``MLACommonImpl.__init__``.

        Raises:
            AssertionError: If ``flash_attn_supports_mla()`` is false.
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, if
                ``attn_type`` is not ``AttentionType.DECODER``, or if the KV
                cache dtype is quantized.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashAttnMLAImpl"
            )

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported"
            )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention over the paged KV cache.

        Args:
            q: Either a ``(q_nope, q_pe)`` pair or a single tensor whose last
                dimension is split into ``kv_lora_rank`` and
                ``qk_rope_head_dim`` parts.
            kv_c_and_k_pe_cache: Non-empty cache tensor; the first
                ``kv_lora_rank`` entries of its last dimension are used as
                values and the remainder as keys.
            attn_metadata: Metadata whose ``decode`` part must be set.
            layer: Not used by this method.

        Returns:
            ``(output, lse)``. ``lse`` is the transposed log-sum-exp tensor
            when ``self.need_to_return_lse_for_decode`` is true, else ``None``.

        Raises:
            AssertionError: If the cache is empty or the decode metadata is
                missing.
            NotImplementedError: If the KV cache dtype starts with ``"fp8"``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
        decode_meta = attn_metadata.decode

        # isinstance also accepts tuple subclasses such as named tuples.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(decode_meta.max_query_len, 1)

        attn_out = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=decode_meta.query_start_loc,
            max_seqlen_k=decode_meta.max_seq_len,
            seqused_k=decode_meta.seq_lens,
            block_table=decode_meta.block_table,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=self.need_to_return_lse_for_decode,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=decode_meta.scheduler_metadata,
            num_splits=decode_meta.max_num_splits,
            cp_world_size=self.dcp_world_size,
            cp_rank=self.dcp_rank,
            cp_tot_seqused_k=decode_meta.dcp_tot_seq_lens,
        )

        if self.need_to_return_lse_for_decode:
            o, lse = attn_out
            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]

        return attn_out, None