flashmla.py 7.75 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
5
from typing import ClassVar
6
7
8

import torch

9
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
10
11
12
from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
13
    is_flashmla_dense_supported,
14
)
15
from vllm.config import VllmConfig
16
from vllm.logger import init_logger
17
18
19
20
21
22
23
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
)
24
from vllm.v1.attention.backends.utils import AttentionCGSupport
25
from vllm.v1.kv_cache_interface import AttentionSpec
26
27
28
29
30
31
32

logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """Backend descriptor that wires together the FlashMLA-specific classes.

    Every accessor is a static method returning a fixed value, so the
    backend can be queried without being instantiated.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the attention-metadata class used by this backend."""
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the class that builds ``FlashMLAMetadata`` instances."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashMLAImpl

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        """Return the supported kernel block sizes (only 64)."""
        return [64]

51
52

@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode-time metadata extended with FlashMLA's scheduling tensors.

    Both fields are produced by ``get_mla_metadata`` in
    ``FlashMLAMetadataBuilder._build_decode`` and are forwarded unchanged
    to ``flash_mla_with_kvcache`` in ``FlashMLAImpl._forward_decode``.
    """

    # First value returned by ``get_mla_metadata``.
    tile_scheduler_metadata: torch.Tensor
    # Second value returned by ``get_mla_metadata``.
    num_splits: torch.Tensor


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """MLA attention metadata whose decode part is ``FlashMLADecodeMetadata``.

    Declares no fields of its own; it only fixes the decode-metadata type
    parameter of ``MLACommonMetadata``.
    """
61
62
63


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``FlashMLAMetadata``, including the FlashMLA decode tensors.

    When the compilation config enables full CUDA graphs, two persistent
    buffers are allocated once in ``__init__`` and every decode build
    copies the freshly computed scheduling tensors into them.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up head count and, if needed, the persistent CUDA-graph buffers.

        Args:
            kv_cache_spec: Forwarded to the base builder.
            layer_names: Forwarded to the base builder.
            vllm_config: Source of the model, parallel and scheduler configs.
            device: Forwarded to the base builder.
        """
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
        )

        model_config = vllm_config.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        # Stay ``None`` unless full CUDA graphs are enabled (see below).
        self.cg_buf_tile_scheduler_metadata: torch.Tensor | None = None
        self.cg_buf_num_splits: torch.Tensor | None = None

        # Queried unconditionally, exactly as before the guard was introduced.
        sm_count = torch.cuda.get_device_properties(self.device).multi_processor_count

        if not self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            return

        # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
        # TileSchedulerMetaDataSize = 8
        self.cg_buf_tile_scheduler_metadata = torch.zeros(
            (sm_count, 8),
            device=self.device,
            dtype=torch.int32,
        )
        # One more entry than the maximum number of sequences.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.cg_buf_num_splits = torch.empty(
            (max_num_seqs + 1),
            device=self.device,
            dtype=torch.int32,
        )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashMLADecodeMetadata:
        """Compute the FlashMLA scheduling tensors for a decode batch.

        Only ``block_table_tensor``, ``seq_lens_device`` and
        ``dcp_tot_seq_lens_device`` are used here; the other parameters are
        accepted but not read by this method.

        Returns:
            A ``FlashMLADecodeMetadata`` whose scheduling tensors are views
            into the persistent buffers when full CUDA graphs are enabled,
            and the tensors returned by ``get_mla_metadata`` otherwise.
        """
        sched_meta, splits = get_mla_metadata(
            seq_lens_device,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            sched_buf = self.cg_buf_tile_scheduler_metadata
            splits_buf = self.cg_buf_num_splits
            assert sched_buf is not None
            assert splits_buf is not None

            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
            used_sm_parts = sched_meta.size(0)
            assert used_sm_parts <= sched_buf.size(0)
            sched_dst = sched_buf[:used_sm_parts]
            sched_dst.copy_(sched_meta)

            # Num splits is per-batch, varying size (batch_size,)
            n = splits.size(0)
            # make sure static buffer is large enough
            assert n <= splits_buf.size(0)
            splits_dst = splits_buf[:n]
            splits_dst.copy_(splits)
            # Num splits needs to be monotonically increasing
            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
            #  it needs to be monotonically increasing by 1), so the unused
            # tail of the buffer is padded with the last computed value.
            splits_buf[n:].fill_(splits[-1])

            # Hand out the views into the persistent buffers from here on.
            sched_meta = sched_dst
            splits = splits_dst

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=sched_meta,
            num_splits=splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
152
153
154


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation that decodes with the FlashMLA kernel."""

    # ``_forward_decode`` returns the LSE alongside the attention output.
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialise the base implementation and validate the configuration.

        All arguments are forwarded unchanged to ``MLACommonImpl``.

        Raises:
            AssertionError: If ``is_flashmla_dense_supported()`` reports that
                FlashMLA is unavailable; the message is the reported reason.
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        is_supported, reason = is_flashmla_dense_supported()
        assert is_supported, reason

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashMLAImpl"
            )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention through ``flash_mla_with_kvcache``.

        Args:
            q: Query tensor, or a tuple of tensors that is concatenated
                along the last dimension before the kernel call.
            kv_c_and_k_pe_cache: KV cache tensor; must be non-empty.
            attn_metadata: Metadata whose ``decode`` part must be set.
            layer: Attention layer providing ``_q_scale`` and ``_k_scale``.

        Returns:
            The ``(output, lse)`` pair returned by the kernel.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
        decode_meta = attn_metadata.decode

        # ``isinstance`` rather than an exact ``type(...) is tuple`` check so
        # that tuple subclasses (e.g. named tuples) are concatenated too
        # instead of tripping the tensor assertion below.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        assert isinstance(q, torch.Tensor)
        o, lse = flash_mla_with_kvcache(
            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=decode_meta.block_table,
            cache_seqlens=decode_meta.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=decode_meta.tile_scheduler_metadata,
            num_splits=decode_meta.num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )

        return o, lse