flashmla.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
5
from typing import ClassVar
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.config.cache import CacheDType
11
from vllm.logger import init_logger
12
13
14
15
16
17
18
19
from vllm.model_executor.layers.attention.mla_attention import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
    QueryLenSupport,
)
20
from vllm.model_executor.layers.batch_invariant import (
21
    vllm_is_batch_invariant,
22
)
23
from vllm.platforms.interface import DeviceCapability
24
25
26
27
28
29
from vllm.v1.attention.backend import (
    AttentionCGSupport,
    AttentionLayer,
    AttentionType,
    MultipleOf,
)
30
31
32
from vllm.v1.attention.backends.utils import (
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
33
)
34
from vllm.v1.attention.ops.flashmla import (
35
    FlashMLASchedMeta,
36
    flash_mla_with_kvcache,
37
    flash_mla_with_kvcache_fp8,
38
    get_mla_metadata,
39
    get_mla_metadata_dense_fp8,
40
41
    is_flashmla_dense_supported,
)
42
from vllm.v1.kv_cache_interface import AttentionSpec
43
44
45
46
47

# Module-level logger, created through vLLM's init_logger with this module's
# __name__.
logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """Backend descriptor for FlashMLA-based MLA attention.

    Declares which model dtypes, KV-cache dtypes, kernel block sizes and
    device compute capabilities are accepted, and names the metadata builder
    and implementation classes that belong to this backend.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8",
        "fp8_e4m3",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the supported kernel block sizes; only 64 is listed."""
        return [64]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA"

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class paired with this backend."""
        return FlashMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True when the capability's major version is 9 or 10."""
        return capability.major in (9, 10)

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Report why FlashMLA cannot be used, or None if it can.

        Only ``use_sparse`` is consulted; it selects between the sparse and
        the dense support probe. The second element of the probe's result is
        returned unchanged. For the dense probe this is the ``reason`` value
        that ``FlashMLAImpl.__init__`` unpacks; the sparse probe presumably
        follows the same ``(is_supported, reason)`` layout -- confirm.
        """
        if use_sparse:
            from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported

            sparse_probe = is_flashmla_sparse_supported()
            return sparse_probe[1]

        from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

        dense_probe = is_flashmla_dense_supported()
        return dense_probe[1]
96

97
98

@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the FlashMLA scheduler metadata."""

    # Filled in by FlashMLAMetadataBuilder._build_decode and read back by
    # FlashMLAImpl._forward_decode when launching the decode kernel.
    scheduler_metadata: FlashMLASchedMeta
101
102
103
104
105


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """MLACommonMetadata specialised to FlashMLADecodeMetadata; adds no fields."""

    pass
106
107
108


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Metadata builder that attaches FlashMLA scheduler metadata to decodes."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    reorder_batch_threshold: int = 128  # process small prefills with decode pathway
    # ^ TODO(matt): tune this

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up the builder and, for full CUDA graphs, its static buffers.

        All arguments are forwarded to the base builder together with the
        ``FlashMLAMetadata`` class. ``self.device`` and
        ``self.compilation_config`` are read afterwards, so they are presumably
        set by the base ``__init__`` -- confirm against MLACommonMetadataBuilder.
        """
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
        )

        model_config = vllm_config.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        # Buffers stay None unless full CUDA graphs are enabled (see below).
        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None
        self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")

        num_sms = torch.cuda.get_device_properties(self.device).multi_processor_count

        if not self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            return

        # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
        # TileSchedulerMetaDataSize = 8
        tile_metadata_shape = (num_sms, 8)
        self.cg_buf_tile_scheduler_metadata = torch.zeros(
            tile_metadata_shape,
            device=self.device,
            dtype=torch.int32,
        )
        # One slot per possible sequence plus one extra entry.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.cg_buf_num_splits = torch.empty(
            max_num_seqs + 1,
            device=self.device,
            dtype=torch.int32,
        )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashMLADecodeMetadata:
        """Create the decode metadata, including FlashMLA scheduler metadata.

        ``max_seq_len``, ``query_start_loc_device`` and ``num_decode_tokens``
        are accepted but not used in this method.

        Returns:
            A ``FlashMLADecodeMetadata`` holding the block table, the sequence
            lengths, the scheduler metadata and ``dcp_tot_seq_lens_device``.
        """
        # Per-request query lengths from consecutive start offsets.
        per_request_query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # we use the max but all should be the same due to uniform length requirement
        max_query_len = per_request_query_lens.max().item()

        num_heads_k = 1  # MQA for the decode path
        num_q_tokens_per_head_k = max_query_len * self.num_q_heads // num_heads_k

        scheduler_metadata, _ = get_mla_metadata(
            seq_lens_device,
            num_q_tokens_per_head_k,
            num_heads_k,
            is_fp8_kvcache=self.is_fp8_kvcache,
        )

        if self.is_fp8_kvcache:
            # The dense FP8 path gets its tile schedule and split counts from a
            # dedicated helper; they are stored on the same metadata object.
            fp8_tile_metadata, fp8_num_splits = get_mla_metadata_dense_fp8(
                seq_lens_device,
                num_q_tokens_per_head_k,
                num_heads_k,
            )
            scheduler_metadata.tile_scheduler_metadata = fp8_tile_metadata
            scheduler_metadata.num_splits = fp8_num_splits

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            scheduler_metadata=scheduler_metadata,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
185
186
187


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path calls FlashMLA kernels."""

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialise the base implementation and reject unsupported setups.

        All arguments are forwarded unchanged to ``MLACommonImpl.__init__``.

        Raises:
            AssertionError: if ``is_flashmla_dense_supported()`` reports that
                dense FlashMLA is unavailable (the message is its reason).
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        is_supported, reason = is_flashmla_dense_supported()
        assert is_supported, reason

        # Truthiness check: None, 0 and empty values count as "not requested".
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashMLAImpl"
            )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention through the FlashMLA kernel.

        Args:
            q: The query, either a single tensor or a tuple of tensors that is
                concatenated along the last dimension.
            kv_c_and_k_pe_cache: The KV cache tensor; must be non-empty. A
                head dimension of size 1 is inserted before the last axis
                when it is handed to the kernel.
            attn_metadata: Metadata whose ``decode`` part must be set.
            layer: Attention layer; only its ``_q_scale`` / ``_k_scale`` are
                read, and only on the FP8 KV-cache path.

        Returns:
            ``(o, lse)`` as returned by the kernel, with ``o`` passed through
            ``reshape_attn_output_for_spec_decode``.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # isinstance (rather than an exact type() comparison) also accepts
        # tuple subclasses and lets type checkers narrow ``q``.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        # mypy assertion: q is now always a tensor
        assert isinstance(q, torch.Tensor)

        num_decodes = attn_metadata.num_decodes
        q = reshape_query_for_spec_decode(q, num_decodes)

        decode_meta = attn_metadata.decode
        scheduler_metadata = decode_meta.scheduler_metadata
        is_fp8_kvcache = self.kv_cache_dtype.startswith("fp8")

        if vllm_is_batch_invariant() and not is_fp8_kvcache:
            device = q.device
            dtype = torch.int32

            B = q.shape[0]
            # block_table shape: [batch_size, max_num_blocks_per_seq]
            # The number of blocks per sequence is in the second dimension
            topk = decode_meta.block_table.shape[-1]
            B_TOPK = 64
            assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
            end_block_idx = topk // B_TOPK

            # Single partition => num_sm_parts = 1
            # TileSchedulerMetaDataSize = 8, layout:
            # [begin_idx, begin_block_idx, end_idx, end_block_idx,
            #  begin_n_split_idx, _, _, _]
            # The tensor starts zeroed, so begin_idx, begin_block_idx,
            # begin_n_split_idx and fields [5..7] are already 0; only the two
            # non-zero fields need to be written.
            tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
            tile_scheduler_metadata[0, 2] = B - 1  # end_idx
            tile_scheduler_metadata[0, 3] = end_block_idx

            # Non-split path ignores num_splits, but the API requires it:
            # zeros of length B+1
            num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)

            # NOTE: this overwrites fields on the scheduler metadata object
            # held by attn_metadata.decode, not on a private copy.
            scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
            scheduler_metadata.num_splits = num_splits

        if is_fp8_kvcache:
            o, lse = flash_mla_with_kvcache_fp8(
                q=q,
                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
                block_table=decode_meta.block_table,
                cache_seqlens=decode_meta.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
                num_splits=scheduler_metadata.num_splits,
                softmax_scale=self.scale,
                causal=True,
                descale_q=layer._q_scale.reshape(1),
                descale_k=layer._k_scale.reshape(1),
            )
        else:
            o, lse = flash_mla_with_kvcache(
                q=q,
                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
                block_table=decode_meta.block_table,
                cache_seqlens=decode_meta.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=scheduler_metadata,
                softmax_scale=self.scale,
                causal=True,
                is_fp8_kvcache=False,
            )

        o = reshape_attn_output_for_spec_decode(o)

        return o, lse