# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FlashMLA attention backend for the vLLM V1 engine (MLA decode path)."""

from dataclasses import dataclass
from typing import Any, ClassVar, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonDecodeMetadata,
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):
    """MLA backend that wires in the FlashMLA metadata, builder and impl."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        backend_name = "FLASHMLA_VLLM_V1"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class used at runtime."""
        impl_cls = FlashMLAImpl
        return impl_cls

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the class that builds per-batch FlashMLA metadata."""
        builder_cls = FlashMLAMetadataBuilder
        return builder_cls

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the metadata container class for this backend."""
        metadata_cls = FlashMLAMetadata
        return metadata_cls


@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Common MLA decode metadata plus the FlashMLA scheduling tensors.

    Both extra fields come from ``get_mla_metadata`` in the metadata builder
    and are forwarded unchanged to ``flash_mla_with_kvcache``.
    """

    # First output of get_mla_metadata (tile scheduler metadata).
    tile_scheduler_metadata: torch.Tensor
    # Second output of get_mla_metadata (split counts / offsets).
    num_splits: torch.Tensor


@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata whose decode part is ``FlashMLADecodeMetadata``."""


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds ``FlashMLAMetadata``, adding FlashMLA scheduling data for decode.

    When ``runner.full_cuda_graph`` is set, the scheduling tensors live in
    persistent buffers. The first decode build allocates them, and every later
    build refreshes them in place. Replayed CUDA graphs therefore keep reading
    the same memory while seeing current values.
    """
    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

        # Query-head count handed to get_mla_metadata, derived from the
        # model config and the parallel config.
        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

        # Persistent CUDA-graph buffers. They stay None until the first
        # _build_decode call made with full CUDA graphs enabled.
        self.cg_buf_tile_scheduler_metadata: Optional[torch.Tensor] = None
        self.cg_buf_num_splits: Optional[torch.Tensor] = None

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        """Build decode metadata, including FlashMLA tile-scheduler data.

        Args:
            block_table_tensor: Block table for the decode requests.
            seq_lens: Sequence lengths for the decode requests.

        Returns:
            FlashMLADecodeMetadata. Under full CUDA graphs, its scheduler
            tensors are the persistent buffers (or a prefix view of them).
        """
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens,
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        if self.runner.full_cuda_graph:
            tile_scheduler_metadata, num_splits = self._stage_for_cudagraph(
                tile_scheduler_metadata, num_splits)

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )

    def _stage_for_cudagraph(
        self, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Route freshly computed scheduler tensors through static buffers.

        Returns:
            (tile_scheduler_metadata, num_splits) backed by the persistent
            CUDA-graph buffers.
        """
        # First time around (CUDAGraph capture): adopt these tensors as the
        # static buffers.
        if self.cg_buf_tile_scheduler_metadata is None:
            self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
            self.cg_buf_num_splits = num_splits
            return tile_scheduler_metadata, num_splits

        assert self.cg_buf_num_splits is not None

        # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
        assert (self.cg_buf_tile_scheduler_metadata.size() ==
                tile_scheduler_metadata.size())
        self.cg_buf_tile_scheduler_metadata.copy_(tile_scheduler_metadata)

        # num_splits varies in length with the batch. Only a prefix of the
        # static buffer is refreshed, so the buffer must be at least as long.
        n = num_splits.size(0)
        assert n <= self.cg_buf_num_splits.size(0)
        num_splits_view = self.cg_buf_num_splits[:n]
        num_splits_view.copy_(num_splits)
        # Pad the unused tail with the last live entry rather than 0, so the
        # buffer stays non-decreasing. NOTE(review): FlashMLA documents
        # num_splits as (batch_size + 1,) cumulative split offsets, and a
        # zero tail breaks that for graph replays reading past `n`. This
        # mirrors the upstream vLLM fix; confirm against this fork's kernel.
        # Indexing keeps the fill value on-device (no host sync).
        self.cg_buf_num_splits[n:].fill_(num_splits[-1] if n > 0 else 0)
        return self.cg_buf_tile_scheduler_metadata, num_splits_view


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path uses FlashMLA.

    This class validates the configuration at construction time and
    overrides ``_forward_decode``. Everything else comes from
    ``MLACommonImpl``.
    """

    # Quantized KV-cache dtypes the FlashMLA decode path accepts.
    _SUPPORTED_QUANT_KV_CACHE_DTYPES: ClassVar[tuple[str, ...]] = (
        "fp8", "fp8_e4m3", "fp8_e5m2")

    # With VLLM_USE_OPT_CAT enabled, inputs with fewer than this many rows
    # (q_nope.shape[0]) use the custom concat helper instead of torch.cat.
    _OPT_CAT_MAX_ROWS: ClassVar[int] = 1024

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state and reject unsupported configs.

        Raises:
            AssertionError: FlashMLA is not supported on this device.
            NotImplementedError: Any of the following holds:
                - alibi / sliding-window / block-sparse / soft-cap is
                  requested;
                - attn_type is not a decoder;
                - the quantized KV-cache dtype is not in
                  _SUPPORTED_QUANT_KV_CACHE_DTYPES.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # NOTE(review): upstream vLLM's is_flashmla_supported() returns an
        # (is_supported, reason) tuple. A bare non-empty tuple is always
        # truthy, so asserting on it directly can never fail. Accept both
        # the bool and the tuple form until this fork's return type is
        # confirmed.
        supported = is_flashmla_supported()
        reason = "FlashMLA is not supported on this device"
        if isinstance(supported, tuple):
            supported, detail = supported
            if detail:
                reason = f"{reason}: {detail}"
        assert supported, reason

        unsupported_features = (alibi_slopes, sliding_window,
                                blocksparse_params, logits_soft_cap)
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        # Use a single membership test instead of an early `return` from
        # __init__. The early return would silently skip any setup added
        # below it later.
        if (is_quantized_kv_cache(self.kv_cache_dtype) and kv_cache_dtype
                not in self._SUPPORTED_QUANT_KV_CACHE_DTYPES):
            raise NotImplementedError(
                "FlashMLA with other KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        # Presumably the K scale for quantized caches; forwarded untouched
        # to flash_mla_with_kvcache. Confirm the expected type with callers.
        k_scale: Any = None,
        kv_cache_dtype: str = "auto",
    ) -> torch.Tensor:
        """Run FlashMLA decode attention and apply the V up-projection.

        Args:
            q_nope: Query part concatenated first along the last dim.
            q_pe: Query part appended after q_nope.
            kv_c_and_k_pe_cache: Non-empty paged KV cache; a head dim of 1
                is added before the kernel call.
            attn_metadata: Batch metadata; ``decode`` must be populated.
            k_scale: Passed through to ``flash_mla_with_kvcache``.
            kv_cache_dtype: Passed through to ``flash_mla_with_kvcache``.

        Returns:
            The output of ``self._v_up_proj`` applied to the kernel output.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if (envs.VLLM_USE_OPT_CAT
                and q_nope.shape[0] < self._OPT_CAT_MAX_ROWS):
            # Lazy import: the helper is only needed when the flag is set.
            # NOTE(review): it lives in a module named `test_concat`;
            # confirm that module is meant to ship in production.
            from vllm.v1.attention.backends.mla.test_concat import (
                concat_helper_decode)
            q = concat_helper_decode(q_nope, q_pe, dim=2)
        else:
            q = torch.cat([q_nope, q_pe], dim=-1)
        q = q.unsqueeze(1)  # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
            k_scale=k_scale,
            kv_cache_dtype=kv_cache_dtype,
        )

        return self._v_up_proj(o)