flashmla.py 11.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
5
from typing import ClassVar
6
7
8

import torch

9
import vllm.envs as envs
10
from vllm.config import VllmConfig
11
from vllm.config.cache import CacheDType
12
from vllm.logger import init_logger
13
14
15
16
17
18
19
20
from vllm.model_executor.layers.attention.mla_attention import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
    QueryLenSupport,
)
21
from vllm.platforms.interface import DeviceCapability
22
from vllm.utils.platform_utils import num_compute_units
23
from vllm.utils.torch_utils import is_quantized_kv_cache
24
25
26
27
28
29
from vllm.v1.attention.backend import (
    AttentionCGSupport,
    AttentionLayer,
    AttentionType,
    MultipleOf,
)
30
31
32
from vllm.v1.attention.backends.utils import (
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
33
)
34
from vllm.v1.attention.ops.flashmla import (
35
    FlashMLASchedMeta,
36
    flash_mla_with_kvcache,
37
    flash_mla_with_kvcache_fp8,
38
    get_mla_metadata,
39
    get_mla_metadata_dense_fp8,
40
41
    is_flashmla_dense_supported,
)
42
from vllm.v1.kv_cache_interface import AttentionSpec
43
44
45
46
47

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """Backend descriptor for FlashMLA-based MLA attention.

    Declares which activation dtypes, KV-cache dtypes, kernel block size and
    compute capabilities this backend accepts, and names the metadata-builder
    and implementation classes that belong to it.
    """

    # Activation dtypes accepted by this backend.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # KV-cache dtype strings accepted by this backend.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "float16",
        "bfloat16",
        "fp8",
        "fp8_e4m3",
    ]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend accepts (only 64)."""
        return [64]

    @staticmethod
    def get_name() -> str:
        """Return the backend's identifier string."""
        return "FLASHMLA"

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the metadata-builder class paired with this backend."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class paired with this backend."""
        return FlashMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True when the capability's major version is 9 or 10."""
        return capability.major in (9, 10)

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Report why this configuration is unsupported, if it is.

        Only ``use_sparse`` is consulted: it selects between the sparse and
        the dense FlashMLA support check. The remaining arguments are accepted
        but not inspected here.

        Returns:
            Element ``[1]`` of the selected support check's result
            (presumably a reason string, or None when supported -- confirm
            against the support-check helpers).
        """
        if use_sparse:
            # Imported lazily, at call time, exactly where it is needed.
            from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported

            sparse_support = is_flashmla_sparse_supported()
            return sparse_support[1]

        from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

        dense_support = is_flashmla_dense_supported()
        return dense_support[1]
97

98
99

@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the FlashMLA scheduler metadata object."""

    # Scheduling metadata object handed to the FlashMLA decode kernels.
    scheduler_metadata: FlashMLASchedMeta
102
103
104
105
106


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialised to ``FlashMLADecodeMetadata``.

    Adds no fields of its own; it only fixes the decode-metadata type.
    """
107
108
109


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``FlashMLAMetadata``, including the FlashMLA decode schedule.

    When full CUDA graphs are enabled, persistent device buffers are allocated
    once in ``__init__`` and the FP8 scheduling tensors are copied into them on
    every ``_build_decode`` call.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    reorder_batch_threshold: int = 128  # process small prefills with decode pathway
    # ^ TODO(matt): tune this

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the builder and, if needed, its CUDA-graph buffers.

        Args:
            kv_cache_spec: Forwarded to the base builder.
            layer_names: Forwarded to the base builder.
            vllm_config: Source of the head count, cache dtype and the
                scheduler's ``max_num_seqs``.
            device: Forwarded to the base builder.
        """
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
        )

        model_config = vllm_config.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        # Persistent buffers; they stay None unless full CUDA graphs are on.
        self.cg_buf_tile_scheduler_metadata: torch.Tensor | None = None
        self.cg_buf_num_splits: torch.Tensor | None = None

        cache_dtype = vllm_config.cache_config.cache_dtype
        self.is_fp8_kvcache = is_quantized_kv_cache(cache_dtype)

        num_sms = num_compute_units(self.device.index)

        if not self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            return

        self.cg_buf_tile_scheduler_metadata = torch.zeros(
            # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
            # TileSchedulerMetaDataSize = 8
            (num_sms, 8),
            device=self.device,
            dtype=torch.int32,
        )
        # One entry per possible sequence, plus one.
        num_splits_len = vllm_config.scheduler_config.max_num_seqs + 1
        self.cg_buf_num_splits = torch.empty(
            num_splits_len,
            device=self.device,
            dtype=torch.int32,
        )

    @staticmethod
    def _copy_into_cg_buffer(cg_buf: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        """Copy ``src`` into the leading rows of ``cg_buf`` and return that view."""
        n = src.size(0)
        assert n <= cg_buf.size(0)
        staged = cg_buf[:n]
        staged.copy_(src)
        return staged

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_device: torch.Tensor,
        max_seq_len: int,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashMLADecodeMetadata:
        """Assemble the decode metadata, including the FlashMLA schedule.

        ``max_seq_len``, ``query_start_loc_device`` and ``num_decode_tokens``
        are part of the signature but are not read by this implementation.

        Returns:
            A ``FlashMLADecodeMetadata`` carrying the block table, sequence
            lengths, scheduler metadata and ``dcp_tot_seq_lens_device``.
        """
        # Per-request query length = difference of consecutive start offsets.
        range_starts = query_start_loc_cpu[:-1]
        range_ends = query_start_loc_cpu[1:]
        query_lens_cpu = range_ends - range_starts
        # we use the max but all should be the same due to uniform length requirement
        max_query_len = query_lens_cpu.max().item()

        num_heads_k = 1  # MQA for the decode path
        num_q_tokens_per_head_k = max_query_len * self.num_q_heads // num_heads_k

        scheduler_metadata, _ = get_mla_metadata(
            seq_lens_device,
            num_q_tokens_per_head_k,
            num_heads_k,
            is_fp8_kvcache=self.is_fp8_kvcache,
        )

        if self.is_fp8_kvcache:
            tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
                seq_lens_device,
                num_q_tokens_per_head_k,
                num_heads_k,
            )

            # Copy FP8 metadata into persistent CUDA graph buffers
            if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
                tile_buf = self.cg_buf_tile_scheduler_metadata
                splits_buf = self.cg_buf_num_splits
                assert tile_buf is not None
                assert splits_buf is not None
                tile_scheduler_metadata = self._copy_into_cg_buffer(
                    tile_buf, tile_scheduler_metadata
                )
                num_splits = self._copy_into_cg_buffer(splits_buf, num_splits)

            scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
            scheduler_metadata.num_splits = num_splits

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            scheduler_metadata=scheduler_metadata,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
202
203
204


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path calls FlashMLA kernels."""

    # forward_mqa returns the (output, lse) pair produced by the kernel call.
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialise the base implementation and reject unsupported setups.

        All arguments are forwarded unchanged to the base class.

        Raises:
            AssertionError: If the dense FlashMLA support check fails; the
                message is the reason returned by that check.
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        is_supported, reason = is_flashmla_dense_supported()
        assert is_supported, reason

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashMLAImpl"
            )

    @staticmethod
    def _single_partition_schedule(
        q: torch.Tensor, block_table: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build a one-partition schedule for the batch-invariant path.

        Returns:
            ``(tile_scheduler_metadata, num_splits)``: an int32 ``(1, 8)``
            tensor and an int32 all-zero tensor of length ``B + 1``, both on
            ``q``'s device, where ``B = q.shape[0]``.

        Raises:
            AssertionError: If the last dimension of ``block_table`` is not a
                multiple of 64.
        """
        device = q.device
        dtype = torch.int32

        B = q.shape[0]
        # block_table shape: [batch_size, max_num_blocks_per_seq]
        # The number of blocks per sequence is in the second dimension
        topk = block_table.shape[-1]
        B_TOPK = 64
        assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
        end_block_idx = topk // B_TOPK

        # Single partition => num_sm_parts = 1
        # TileSchedulerMetaDataSize = 8, layout:
        # [begin_idx, begin_block_idx, end_idx, end_block_idx,
        #  begin_n_split_idx, _, _, _]
        # The tensor starts zeroed, so begin_idx, begin_block_idx,
        # begin_n_split_idx and fields [5..7] already hold their value (0);
        # only the two non-zero fields need to be written.
        tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
        tile_scheduler_metadata[0, 2] = B - 1  # end_idx
        tile_scheduler_metadata[0, 3] = end_block_idx

        # Non-split path ignores num_splits, but the API requires it:
        # zeros of length B+1
        num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
        return tile_scheduler_metadata, num_splits

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the FlashMLA decode kernel for the decode part of the batch.

        Args:
            q: Query tensor, or a tuple of tensors that is concatenated along
                the last dimension before use.
            kv_c_and_k_pe_cache: KV cache tensor; must be non-empty. A head
                dimension of size 1 is inserted before the kernel call.
            attn_metadata: Metadata whose ``decode`` entry must be set.
            layer: Supplies ``_q_scale`` and ``_k_scale`` on the quantized
                KV-cache path.

        Returns:
            ``(o, lse)`` as returned by the kernel, with ``o`` passed through
            ``reshape_attn_output_for_spec_decode``.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # isinstance (rather than an exact type check) so tuple subclasses
        # such as NamedTuple are concatenated as well.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        # mypy assertion: q is now always a tensor
        assert isinstance(q, torch.Tensor)

        num_decodes = attn_metadata.num_decodes
        q = reshape_query_for_spec_decode(q, num_decodes)

        decode_meta = attn_metadata.decode
        scheduler_metadata = decode_meta.scheduler_metadata
        is_fp8_kvcache = is_quantized_kv_cache(self.kv_cache_dtype)

        if is_fp8_kvcache:
            o, lse = flash_mla_with_kvcache_fp8(
                q=q,
                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
                block_table=decode_meta.block_table,
                cache_seqlens=decode_meta.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
                num_splits=scheduler_metadata.num_splits,
                softmax_scale=self.scale,
                causal=True,
                descale_q=layer._q_scale.reshape(1),
                descale_k=layer._k_scale.reshape(1),
            )
        else:
            if envs.VLLM_BATCH_INVARIANT:
                # Batch-invariant mode: overwrite the schedule stored on the
                # metadata object with a single-partition one.
                tile_scheduler_metadata, num_splits = self._single_partition_schedule(
                    q, decode_meta.block_table
                )
                scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
                scheduler_metadata.num_splits = num_splits

            o, lse = flash_mla_with_kvcache(
                q=q,
                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
                block_table=decode_meta.block_table,
                cache_seqlens=decode_meta.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=scheduler_metadata,
                softmax_scale=self.scale,
                causal=True,
                is_fp8_kvcache=False,
            )

        o = reshape_attn_output_for_spec_decode(o)

        return o, lse