flashmla.py 9.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
5
from typing import ClassVar
6
7
8

import torch

9
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
10
11
12
from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
13
    is_flashmla_dense_supported,
14
)
15
from vllm.config import VllmConfig
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
20
21
22
23
24
25
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
26
27
28
29
30
31
    QueryLenSupport,
)
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    reshape_attn_output_for_spec_decode,
    reshape_query_for_spec_decode,
32
)
33
from vllm.v1.kv_cache_interface import AttentionSpec
34
35
36
37
38
39
40

# Module logger obtained via vllm.logger.init_logger; not referenced by the
# code visible in this section.
logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """MLA attention backend that routes decode through the FlashMLA kernels.

    Each static hook hands the generic MLA machinery one of the
    FlashMLA-specific classes defined in this module (implementation,
    metadata builder, metadata), plus the backend name and the KV-cache
    block sizes it accepts.
    """

    @staticmethod
    def get_name() -> str:
        # Name string identifying this backend.
        return "FLASHMLA"

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        return FlashMLAImpl

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        # Only a block size of exactly 64 is advertised.
        return [64]

@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Common MLA decode metadata plus the FlashMLA scheduling tensors.

    Both extra fields are filled by ``FlashMLAMetadataBuilder._build_decode``
    from ``get_mla_metadata`` (possibly copied into persistent CUDA-graph
    buffers) and read by ``FlashMLAImpl._forward_decode``.
    """

    # Per-SM-partition work assignment produced by get_mla_metadata.
    tile_scheduler_metadata: torch.Tensor
    # Split counts produced by get_mla_metadata (batch_size + 1 entries).
    num_splits: torch.Tensor


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata parameterised with FlashMLADecodeMetadata.

    Declares no fields of its own; the generic argument is the only change.
    """


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds FlashMLAMetadata, including FlashMLA tile-scheduler metadata.

    When full CUDA graphs are enabled, the scheduler outputs are copied into
    persistent buffers allocated once in ``__init__`` (presumably so captured
    graphs keep reading the same memory on replay), and the returned metadata
    holds views into those buffers.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    reorder_batch_threshold: int = 512  # process small prefills with decode pathway
    # ^ TODO(matt): tune this

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
        )

        # Query-head count under the current parallel config; passed to
        # get_mla_metadata when scheduling decode tiles.
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        # Persistent scheduler buffers; only allocated with full CUDA graphs.
        self.cg_buf_tile_scheduler_metadata: torch.Tensor | None = None
        self.cg_buf_num_splits: torch.Tensor | None = None

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # The SM count only sizes the scheduler buffer, so the device is
            # queried only when that buffer is actually allocated.
            num_sms = torch.cuda.get_device_properties(
                self.device
            ).multi_processor_count
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            # num_splits carries batch_size + 1 entries, hence the + 1 over
            # the maximum number of scheduled sequences.
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32,
            )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashMLADecodeMetadata:
        """Build decode metadata together with FlashMLA scheduler tensors.

        Only ``block_table_tensor``, ``seq_lens_device`` and
        ``dcp_tot_seq_lens_device`` are used here; the other parameters are
        part of the shared builder hook signature.

        Returns:
            FlashMLADecodeMetadata. With full CUDA graphs enabled its
            ``tile_scheduler_metadata`` and ``num_splits`` are views into the
            persistent buffers; otherwise they are the freshly computed
            tensors from ``get_mla_metadata``.
        """
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens_device,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            assert self.cg_buf_tile_scheduler_metadata is not None
            assert self.cg_buf_num_splits is not None

            # Metadata is per SM partition; its row count is bounded by the
            # #SMs used to size the persistent buffer.
            sm_parts = tile_scheduler_metadata.size(0)
            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
            tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
                :sm_parts
            ]
            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
            tile_scheduler_metadata = tile_scheduler_metadata_view

            # num_splits has batch_size + 1 entries, so its length varies
            # with the batch; make sure the static buffer is large enough.
            n = num_splits.size(0)
            assert n <= self.cg_buf_num_splits.size(0)
            num_splits_view = self.cg_buf_num_splits[:n]
            num_splits_view.copy_(num_splits)
            # num_splits must be monotonically increasing (with
            # https://github.com/vllm-project/FlashMLA/pull/3; otherwise it
            # must increase monotonically by 1), so the unused tail of the
            # buffer is padded with the last valid value.
            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
            num_splits = num_splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path calls FlashMLA.

    ALiBi slopes, sliding windows, logits soft-capping and non-decoder
    attention types are rejected at construction time.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialise the common MLA state and validate FlashMLA support.

        Raises:
            AssertionError: if ``is_flashmla_dense_supported()`` reports the
                dense kernel as unavailable (its reason is the message).
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        is_supported, reason = is_flashmla_dense_supported()
        assert is_supported, reason

        # Checked by truthiness: None, 0 and empty values count as "unset".
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashMLAImpl"
            )

    @staticmethod
    def _batch_invariant_scheduler_metadata(
        q: torch.Tensor, block_table: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build single-partition scheduler inputs for batch-invariant decode.

        All requests are placed in one SM partition (num_sm_parts = 1) and
        ``num_splits`` is all zeros; presumably this avoids batch-dependent
        split-KV reductions.

        Args:
            q: decode queries; ``q.shape[0]`` is taken as the batch size.
            block_table: block table whose last dim is used as ``topk``.

        Returns:
            ``(tile_scheduler_metadata, num_splits)`` as int32 tensors on
            ``q.device`` with shapes ``(1, 8)`` and ``(B + 1,)``.

        Raises:
            AssertionError: if ``block_table.shape[-1]`` is not a multiple
                of 64.
        """
        device = q.device
        dtype = torch.int32

        B = q.shape[0]
        # block_table shape: [batch_size, max_num_blocks_per_seq]
        # The number of blocks per sequence is in the second dimension
        topk = block_table.shape[-1]
        B_TOPK = 64
        assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
        # NOTE(review): topk counts KV-cache *blocks* (per the comment above),
        # yet it is divided by 64 again as if it counted tokens. Confirm this
        # matches the unit FlashMLA's scheduler expects for end_block_idx.
        end_block_idx = topk // B_TOPK

        # TileSchedulerMetaDataSize = 8, layout:
        # [begin_idx, begin_block_idx, end_idx, end_block_idx,
        #  begin_n_split_idx, _, _, _]
        # begin_idx, sched_begin_block_idx, begin_n_split_idx and fields
        # [5..7] are 0, which torch.zeros already provides.
        tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
        tile_scheduler_metadata[0, 2] = B - 1  # end_idx
        tile_scheduler_metadata[0, 3] = end_block_idx

        # Non-split path ignores num_splits, but the API requires it:
        # zeros of length B+1
        num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
        return tile_scheduler_metadata, num_splits

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run FlashMLA decode attention against the paged KV cache.

        Args:
            q: decode queries, either a single tensor or a pair of tensors
                that is concatenated along the last dimension.
            kv_c_and_k_pe_cache: KV cache; must be non-empty. A head dim of
                1 is inserted at position -2 before the kernel call.
            attn_metadata: metadata whose ``decode`` part must be populated.
            layer: supplies the ``_q_scale`` / ``_k_scale`` descale factors.

        Returns:
            ``(o, lse)`` from ``flash_mla_with_kvcache``, with ``o`` passed
            through ``reshape_attn_output_for_spec_decode``.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        # mypy assertion: q is now always a tensor
        assert isinstance(q, torch.Tensor)

        num_decodes = attn_metadata.num_decodes
        q = reshape_query_for_spec_decode(q, num_decodes)

        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
        num_splits = attn_metadata.decode.num_splits
        if vllm_is_batch_invariant():
            # Replace the builder's scheduler outputs with a fixed
            # single-partition schedule.
            tile_scheduler_metadata, num_splits = (
                self._batch_invariant_scheduler_metadata(
                    q, attn_metadata.decode.block_table
                )
            )

        o, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )

        o = reshape_attn_output_for_spec_decode(o)

        return o, lse