flashattn_mla.py 11.7 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
5
from typing import ClassVar
6
7
8

import torch

9
10
11
from vllm.attention.backends.abstract import (
    AttentionLayer,
    AttentionType,
12
    MultipleOf,
13
14
15
16
17
18
    is_quantized_kv_cache,
)
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_mla,
    get_flash_attn_version,
)
19
from vllm.config import VllmConfig
20
from vllm.config.cache import CacheDType
21
from vllm.logger import init_logger
22
from vllm.model_executor.layers.batch_invariant import (
23
    vllm_is_batch_invariant,
24
)
25
from vllm.platforms.interface import DeviceCapability
26
27
28
29
30
31
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
32
    QueryLenSupport,
33
)
34
from vllm.v1.attention.backends.utils import AttentionCGSupport
35
36
37
38
39
40
41
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)


class FlashAttnMLABackend(MLACommonBackend):
    """MLA attention backend built on FlashAttention kernels.

    Advertised restrictions (see the class attributes and methods below):
    fp16/bf16 activations, only the "auto" (unquantized) KV-cache dtype,
    kernel block sizes that are a multiple of 16, and devices whose compute
    capability major version is 9.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend accepts (multiples of 16)."""
        return [MultipleOf(16)]

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        """Return the attention implementation class paired with this backend."""
        return FlashAttnMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True only for devices with compute capability major == 9."""
        return capability.major == 9

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Check whether this backend can serve the given configuration.

        Only the availability of FlashAttention MLA support is checked here;
        the other arguments are accepted for interface compatibility.

        Returns:
            None if supported, otherwise a human-readable reason string.
        """
        if not flash_attn_supports_mla():
            return "FlashAttention MLA not supported on this device"
        return None

81
82
83
84
85
86

@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode-phase metadata consumed by FlashAttnMLAImpl._forward_decode.

    Field order and defaults are part of the constructor interface; do not
    reorder.
    """

    # Device tensor of cumulative query lengths (passed as cu_seqlens_q).
    query_start_loc: torch.Tensor
    # Largest per-request query length in the decode batch.
    max_query_len: int
    # Largest per-request KV sequence length in the decode batch.
    max_seq_len: int
    # Output of the FA ahead-of-time scheduler (get_scheduler_metadata);
    # None when AOT scheduling is not used.
    scheduler_metadata: torch.Tensor | None = None
    # Split count forwarded to the kernel as num_splits; the builder uses 0
    # to mean "no upper bound on the number of splits".
    max_num_splits: int = 0
89
90
91
92
93
94
95


@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    """Common MLA metadata specialized with FlashAttnMLADecodeMetadata.

    Adds no fields of its own; it only binds the decode-metadata type.
    """


96
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    """Builds FlashAttention-MLA metadata for each batch.

    On top of the common MLA builder this handles FA3 ahead-of-time (AOT)
    decode scheduling and, when full CUDA graphs are enabled, a persistent
    scheduler-metadata buffer plus a capped number of splits.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
    reorder_batch_threshold: int = 512  # process small prefills with decode pathway

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(
            kv_cache_spec,
            layer_names,
            vllm_config,
            device,
            FlashAttnMLAMetadata,
            supports_dcp_with_varlen=True,
        )
        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling via get_scheduler_metadata is only used with FA3.
        self.fa_aot_schedule = get_flash_attn_version() == 3

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            # Persistent buffer that _build_decode copies scheduler metadata
            # into, sized max_num_seqs + 1 entries.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
            )

        if vllm_is_batch_invariant():
            # Batch-invariant mode pins a single split (presumably so the
            # reduction does not depend on batch composition).
            self.max_num_splits = 1

    def _schedule_decode(
        self,
        num_reqs: int,
        cu_query_lens: torch.Tensor,
        max_query_len: int,
        seqlens: torch.Tensor,
        max_seq_len: int,
        causal: bool,
        max_num_splits: int,
    ) -> torch.Tensor | None:
        """Run FA3 AOT scheduling for a decode batch.

        Returns:
            The tensor from get_scheduler_metadata when AOT scheduling is
            enabled (FA3), otherwise None.
        """
        if self.fa_aot_schedule:
            return get_scheduler_metadata(
                batch_size=num_reqs,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads * self.dcp_world_size,
                num_heads_kv=1,
                headdim=self.mla_dims.qk_rope_head_dim,
                cache_seqlens=seqlens,
                qkv_dtype=self.kv_cache_spec.dtype,
                headdim_v=self.mla_dims.kv_lora_rank,
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
                num_splits=max_num_splits,
            )
        return None

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashAttnMLADecodeMetadata:
        """Assemble decode metadata for the current batch.

        Computes the max query/sequence lengths from the CPU tensors, picks
        the split count, runs AOT scheduling when available, and under full
        CUDA graphs routes the scheduler metadata through the persistent
        buffer allocated in __init__.

        Raises:
            AssertionError: if the scheduler metadata is larger than the
                persistent buffer.
        """
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        # For Flash Attention MLA + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        if vllm_is_batch_invariant():
            max_num_splits = 1

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
            max_num_splits=max_num_splits,
        )

        # scheduler_metadata is only non-None with AOT scheduling, which is
        # exactly when the persistent buffer was allocated in __init__.
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert n <= self.scheduler_metadata.shape[0], (
                f"Scheduler metadata size {n} exceeds buffer size "
                f"{self.scheduler_metadata.shape[0]}"
            )
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
230
231
232


class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
    """MLA attention implementation whose decode path calls FA3
    flash_attn_varlen_func over the latent KV cache."""

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialize the impl and reject unsupported configurations.

        Raises:
            AssertionError: if FlashAttention MLA is unavailable.
            NotImplementedError: for alibi slopes, sliding window, logits
                soft-cap, non-decoder attention, or a quantized KV cache.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"

        # Truthiness check: any non-empty / non-zero value counts as "set".
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashAttnMLAImpl"
            )

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported"
            )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention over the combined latent/rope KV cache.

        Args:
            q: either a (q_nope, q_pe) pair, or one tensor whose last dim is
                kv_lora_rank + qk_rope_head_dim and is split into the two.
            kv_c_and_k_pe_cache: cache whose last dim holds the latent part
                (first kv_lora_rank entries) followed by the rope part.
            attn_metadata: batch metadata; its ``decode`` field must be set.
            layer: unused here; kept for interface compatibility.

        Returns:
            (output, lse) where lse is transposed from FA's [H, B] to
            [B, H] when requested, otherwise None.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        # isinstance also accepts tuple subclasses (e.g. a namedtuple pair),
        # which torch.split below could not handle.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(decode_meta.max_query_len, 1)

        attn_out = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=decode_meta.query_start_loc,
            max_seqlen_k=decode_meta.max_seq_len,
            seqused_k=decode_meta.seq_lens,
            block_table=decode_meta.block_table,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=self.need_to_return_lse_for_decode,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=decode_meta.scheduler_metadata,
            num_splits=decode_meta.max_num_splits,
            cp_world_size=self.dcp_world_size,
            cp_rank=self.dcp_rank,
            cp_tot_seqused_k=decode_meta.dcp_tot_seq_lens,
        )

        if self.need_to_return_lse_for_decode:
            o, lse = attn_out
            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
        return attn_out, None