# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FlashMLA-based MLA attention backend for the vLLM V1 engine."""

from dataclasses import dataclass
from typing import ClassVar, Optional, Union

import torch

from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms.cuda import CudaPlatform
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonDecodeMetadata,
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """MLA backend entry point that plugs the FlashMLA metadata, metadata
    builder and attention implementation into the common MLA backend."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the attention-metadata dataclass used by this backend."""
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the class that builds ``FlashMLAMetadata`` per step."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashMLAImpl


@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode-phase metadata extended with FlashMLA's scheduling tensors.

    Both extra fields come from ``get_mla_metadata`` (optionally copied into
    persistent CUDA-graph buffers by the metadata builder) and are handed
    straight to ``flash_mla_with_kvcache``.
    """

    # Per-SM tile scheduling data; the builder's persistent buffer for it is
    # int32 of shape (num_sms, 8) and this field is a row-prefix view of it.
    tile_scheduler_metadata: torch.Tensor
    # Split counts for the kernel; the builder pads its buffer tail with the
    # last value to keep it monotonic (presumably cumulative per-request
    # offsets -- TODO confirm against the FlashMLA API).
    num_splits: torch.Tensor


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata whose decode part is ``FlashMLADecodeMetadata``;
    it declares no additional fields."""


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``FlashMLAMetadata``, adding FlashMLA's tile-scheduler metadata
    and split counts to the decode part of each batch.

    When full CUDA graphs are enabled, those tensors are copied into buffers
    allocated once in ``__init__``, so captured graphs keep reading from the
    same storage on every step.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        # Number of query heads for this parallel config; passed to
        # get_mla_metadata together with a single KV head (MQA decode).
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config)

        # Persistent buffers, only allocated when full CUDA graphs are used.
        self.cg_buf_tile_scheduler_metadata: Optional[torch.Tensor] = None
        self.cg_buf_num_splits: Optional[torch.Tensor] = None

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # The SM count is only needed to size this buffer, so the device
            # properties are queried only on this path.
            num_sms = torch.cuda.get_device_properties(
                self.device).multi_processor_count
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            # num_splits from get_mla_metadata has batch_size + 1 entries
            # (per the FlashMLA API), hence the extra slot.
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
        """Build the decode metadata, including FlashMLA scheduler tensors.

        Only ``block_table_tensor`` and ``seq_lens_device`` are used here; the
        remaining parameters belong to the base-class hook signature.

        Returns:
            FlashMLADecodeMetadata whose scheduler tensors are views into the
            persistent buffers when full CUDA graphs are enabled, and the
            freshly computed tensors otherwise.
        """
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens_device,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            assert self.cg_buf_tile_scheduler_metadata is not None
            assert self.cg_buf_num_splits is not None

            # Metadata is per-SM: (sm_parts, TileSchedulerMetaDataSize) with
            # sm_parts <= #SMs, so it fits in the (num_sms, 8) buffer.
            sm_parts = tile_scheduler_metadata.size(0)
            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
            tile_scheduler_metadata_view = \
                self.cg_buf_tile_scheduler_metadata[:sm_parts]
            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
            tile_scheduler_metadata = tile_scheduler_metadata_view

            # Num splits varies with the batch: shape (batch_size + 1,).
            n = num_splits.size(0)
            # make sure static buffer is large enough
            assert n <= self.cg_buf_num_splits.size(0)
            num_splits_view = self.cg_buf_num_splits[:n]
            num_splits_view.copy_(num_splits)
            # Num splits needs to be monotonically increasing
            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
            #  it needs to be monotonically increasing by 1), so pad the unused
            # tail of the buffer with the last valid value.
            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
            num_splits = num_splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path calls FlashMLA's
    ``flash_mla_with_kvcache`` kernel on the paged latent KV cache."""

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Validate the device and the requested features, then initialize.

        Raises:
            AssertionError: if FlashMLA reports it is unsupported here.
            NotImplementedError: on SM 10.0+ devices, when alibi_slopes,
                sliding_window or logits_soft_cap is set (truthy), or when
                attn_type is not a decoder attention type.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # is_flashmla_supported() is expected to return an
        # (is_supported, reason) tuple. A non-empty tuple is always truthy, so
        # asserting on the raw return value could never fail; unpack it first
        # (a bare bool is still accepted).
        support = is_flashmla_supported()
        if isinstance(support, tuple):
            is_supported, reason = support
        else:
            is_supported, reason = bool(support), None
        unsupported_msg = "FlashMLA is not supported on this device"
        assert is_supported, (f"{unsupported_msg}: {reason}"
                              if reason else unsupported_msg)

        # disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
        # context:
        # https://github.com/deepseek-ai/FlashMLA/issues/83
        # https://github.com/vllm-project/vllm/issues/24513
        if CudaPlatform.has_device_capability(100):
            raise NotImplementedError(
                "FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
                "Please use CUTLASS_MLA or TRITON_MLA instead. "
                "Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run FlashMLA decode attention for the decode part of the batch.

        Args:
            q: decode queries, either one tensor or a tuple of tensors that is
                concatenated along the last dim (presumably the latent and
                rope parts -- TODO confirm with the caller).
            kv_c_and_k_pe_cache: non-empty paged cache; a head dim of 1 is
                inserted at -2 before calling the kernel.
            attn_metadata: metadata whose ``decode`` part must be set.
            layer: supplies the ``_q_scale``/``_k_scale`` descale factors.

        Returns:
            ``(output, lse)`` as returned by ``flash_mla_with_kvcache``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
        decode_meta = attn_metadata.decode

        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)
        assert isinstance(q, torch.Tensor)

        o, lse = flash_mla_with_kvcache(
            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=decode_meta.block_table,
            cache_seqlens=decode_meta.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=decode_meta.tile_scheduler_metadata,
            num_splits=decode_meta.num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )

        return o, lse