flashmla.py 10.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
5
from typing import Any, ClassVar, Optional
6
7
8

import torch

9
10
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
11
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
12
                                         flash_mla_with_kvcache_q_nope_pe,
13
                                         get_mla_metadata,
14
15
                                         flash_mla_with_kvcache_fp8,
                                         get_mla_decoding_metadata_dense_fp8,
16
17
18
                                         is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
19
                                                   MLACommonDecodeMetadata,
20
21
22
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
23
24
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
25
from vllm import envs
26

27
28
29
30
31
32
33
34
35
36
37

# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """vLLM V1 MLA attention backend whose decode path uses FlashMLA kernels.

    Only the backend name and the metadata / builder / impl classes are
    overridden; everything else comes from ``MLACommonBackend``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the per-step attention metadata class."""
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the class that builds ``FlashMLAMetadata``."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class."""
        return FlashMLAImpl


@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the FlashMLA scheduling tensors.

    Both extra fields are produced by ``get_mla_metadata`` when the decode
    metadata is built.
    """
    # Per-SM tile schedule consumed by the FlashMLA decode kernel
    # (fixed size (#SMs, TileMetadataSize) per the builder's comment).
    tile_scheduler_metadata: torch.Tensor
    # Split counts whose length varies with the decode batch size.
    num_splits: torch.Tensor


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialised to carry FlashMLA decode metadata."""


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``FlashMLAMetadata``, adding FlashMLA scheduling data for decode.

    When the runner uses full CUDA graphs, the scheduler tensors are routed
    through persistent buffers so that every build hands out the same tensor
    storage that existed when the graph was captured.
    """
    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

        # Static buffers for full-CUDA-graph mode: adopted on the first decode
        # build (graph capture) and copied into on every later build.
        self.cg_buf_tile_scheduler_metadata: Optional[torch.Tensor] = None
        self.cg_buf_num_splits: Optional[torch.Tensor] = None

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        """Build decode metadata, including FlashMLA tile-scheduler data.

        Args:
            block_table_tensor: block table for the decode requests.
            seq_lens: sequence lengths of the decode requests.

        Returns:
            FlashMLADecodeMetadata holding the inputs plus the
            ``get_mla_metadata`` outputs (possibly as views into the static
            CUDA-graph buffers).
        """
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        if self.runner.full_cuda_graph:
            tile_scheduler_metadata, num_splits = self._to_cudagraph_buffers(
                tile_scheduler_metadata, num_splits)

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )

    def _to_cudagraph_buffers(
        self, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Route scheduler tensors through the persistent CUDA-graph buffers.

        On the first call (CUDA-graph capture) the given tensors become the
        static buffers and are returned unchanged. On later calls the new
        values are copied into the static buffers, and the buffers (or a
        prefix view of them) are returned instead of the fresh tensors.

        The ``num_splits`` buffer must be at least as long as any later
        ``num_splits``. That means the first (capture) call must see the
        largest batch. NOTE(review): this relies on the capture order used
        by the runner; confirm there.
        """
        if self.cg_buf_tile_scheduler_metadata is None:
            # First time around (CUDAGraph capture): adopt as static buffers.
            self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
            self.cg_buf_num_splits = num_splits
            return tile_scheduler_metadata, num_splits

        assert self.cg_buf_num_splits is not None

        # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
        assert (self.cg_buf_tile_scheduler_metadata.size() ==
                tile_scheduler_metadata.size())
        self.cg_buf_tile_scheduler_metadata.copy_(tile_scheduler_metadata)

        # Num splits varies in size with the batch; copy into a prefix view
        # of the static buffer and zero the unused tail.
        n = num_splits.size(0)
        assert n <= self.cg_buf_num_splits.size(0)
        num_splits_view = self.cg_buf_num_splits[:n]
        num_splits_view.copy_(num_splits)
        self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
        return self.cg_buf_tile_scheduler_metadata, num_splits_view


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path uses FlashMLA kernels.

    ``_forward_decode`` dispatches to one of three kernels:

    * ``flash_mla_with_kvcache_fp8``: used when ``kv_cache_dtype`` is
      ``"fp8_e4m3"``, ``VLLM_USE_FLASH_MLA_FP8`` is set, and the device
      architecture is ``gfx938``.
    * ``flash_mla_with_kvcache_q_nope_pe``: used when ``VLLM_USE_CAT_MLA`` is
      set and ``kv_cache_dtype`` is not ``"fp8_e4m3"``. Here q_nope and q_pe
      are passed separately.
    * ``flash_mla_with_kvcache``: used in all other cases, with a
      concatenated q.
    """

    # Quantized KV-cache dtypes accepted by this implementation.
    _FP8_KV_CACHE_DTYPES = ("fp8", "fp8_e4m3", "fp8_e5m2")

    # With VLLM_USE_OPT_CAT, q_nope.shape[0] below this uses the custom
    # concat_helper_decode op instead of torch.cat.
    _OPT_CAT_MAX_TOKENS = 1024

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the impl and reject configurations FlashMLA can't run.

        Raises:
            AssertionError: if ``is_flashmla_supported()`` is false.
            NotImplementedError: if any of alibi_slopes, sliding_window,
                blocksparse_params or logits_soft_cap is truthy; if
                attn_type is not a decoder; or if the KV cache is quantized
                with a dtype other than fp8 / fp8_e4m3 / fp8_e5m2.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        if (is_quantized_kv_cache(self.kv_cache_dtype)
                and kv_cache_dtype not in self._FP8_KV_CACHE_DTYPES):
            raise NotImplementedError(
                "FlashMLA with other KV cache not yet supported")

        # Lazily filled by _is_gfx938() on the first decode step.
        self._gfx938: Optional[bool] = None

    def _is_gfx938(self) -> bool:
        """Return whether the current device reports a gfx938 architecture.

        The device query runs once per instance and is cached, instead of
        running on every decode step. ``gcnArchName`` may be absent on
        non-ROCm builds of PyTorch, so a missing field counts as "not
        gfx938".
        """
        if self._gfx938 is None:
            props = torch.cuda.get_device_properties("cuda")
            arch = getattr(props, "gcnArchName", "")
            self._gfx938 = arch.split(":")[0] == "gfx938"
        return self._gfx938

    def _concat_q(self, q_nope: torch.Tensor,
                  q_pe: torch.Tensor) -> torch.Tensor:
        """Concatenate q_nope and q_pe, then add a seqlen dim of 1 at dim 1.

        With ``VLLM_USE_OPT_CAT`` and ``q_nope.shape[0]`` below
        ``_OPT_CAT_MAX_TOKENS``, the custom ``concat_helper_decode`` op is
        used along dim 2. Otherwise ``torch.cat`` is used along the last dim.
        NOTE(review): both paths are equivalent only if q_nope / q_pe are
        3-D, so that dim 2 is the last dim. Confirm with the caller.
        """
        if (envs.VLLM_USE_OPT_CAT
                and q_nope.shape[0] < self._OPT_CAT_MAX_TOKENS):
            from vllm.v1.attention.backends.mla.test_concat import (
                concat_helper_decode)
            q = concat_helper_decode(q_nope, q_pe, dim=2)
        else:
            q = torch.cat([q_nope, q_pe], dim=-1)
        return q.unsqueeze(1)  # Add seqlen dim of 1 (decode)

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        q_scale=None,
        k_scale=None,
        kv_cache_dtype="auto",
    ) -> torch.Tensor:
        """Run FlashMLA decode attention and apply the V up-projection.

        Args:
            q_nope: query part concatenated (or passed separately) with q_pe.
            q_pe: rotary-embedding part of the query.
            kv_c_and_k_pe_cache: non-empty paged KV cache. On the fp8 path it
                is reinterpreted as ``torch.float8_e4m3fn``. A head dim of 1
                is added at dim -2 on every path.
            attn_metadata: must carry non-None ``decode`` metadata.
            q_scale: forwarded as ``descale_q`` on the fp8 path only.
            k_scale: forwarded as ``descale_k`` (fp8 path) or ``k_scale``.
            kv_cache_dtype: KV-cache dtype string used for kernel selection.

        Returns:
            ``self._v_up_proj`` applied to the kernel output.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None
        decode = attn_metadata.decode

        # Cheap checks first, so the device query only runs when fp8 decode
        # is actually requested.
        if (kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8
                and self._is_gfx938()):
            q = self._concat_q(q_nope, q_pe)
            o, _ = flash_mla_with_kvcache_fp8(
                q=q.to(torch.float8_e4m3fn),
                # Add head dim of 1
                k_cache=kv_c_and_k_pe_cache.view(
                    torch.float8_e4m3fn).unsqueeze(-2),
                block_table=decode.block_table,
                cache_seqlens=decode.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=decode.tile_scheduler_metadata,
                num_splits=decode.num_splits,
                softmax_scale=self.scale,
                causal=True,
                descale_q=q_scale,
                descale_k=k_scale,
            )
        elif envs.VLLM_USE_CAT_MLA and kv_cache_dtype != "fp8_e4m3":
            # Kernel takes q_nope / q_pe separately; no concatenation needed.
            o, _ = flash_mla_with_kvcache_q_nope_pe(
                q_nope=q_nope.unsqueeze(1),
                q_pe=q_pe.unsqueeze(1),
                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
                block_table=decode.block_table,
                cache_seqlens=decode.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=decode.tile_scheduler_metadata,
                num_splits=decode.num_splits,
                softmax_scale=self.scale,
                causal=True,
                k_scale=k_scale,
                kv_cache_dtype=kv_cache_dtype,
            )
        else:
            o, _ = flash_mla_with_kvcache(
                q=self._concat_q(q_nope, q_pe),
                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
                block_table=decode.block_table,
                cache_seqlens=decode.seq_lens,
                head_dim_v=self.kv_lora_rank,
                tile_scheduler_metadata=decode.tile_scheduler_metadata,
                num_splits=decode.num_splits,
                softmax_scale=self.scale,
                causal=True,
                k_scale=k_scale,
                kv_cache_dtype=kv_cache_dtype,
            )

        return self._v_up_proj(o)