flashmla.py 9.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
10
from vllm import envs
11
12
13
14
15
16
17
18
19
20
21
22

logger = init_logger(__name__)

# Probe for the compiled FlashMLA extension of the current platform. The
# outcome is only recorded in `_flashmla_C_AVAILABLE` (read by
# is_flashmla_supported()), so a missing extension does not prevent this
# module from being imported.
_flashmla_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False

if current_platform.is_rocm():
    try:
        import flash_mla_cuda
        _flashmla_C_AVAILABLE = True
    except ImportError as exc:
        # Mirror the CUDA branch: report "unavailable" through the flag
        # instead of failing the import of this module when the ROCm kernel
        # package is not installed.
        logger.debug("flash_mla_cuda could not be imported: %s", exc)
        _flashmla_C_AVAILABLE = False
26
27
28
29
30

def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether FlashMLA can be used on the current platform.

    The checks run in this order: the platform is CUDA or ROCm, the major
    device capability equals 9, and the module-level
    `_flashmla_C_AVAILABLE` flag is set.

    Return: is_supported_flag, unsupported_reason (None when supported).
    """
    on_gpu_platform = (current_platform.is_cuda()
                       or current_platform.is_rocm())
    if not on_gpu_platform:
        return False, "FlashMLA is supported on CUDA and ROCM devices."

    major_capability = current_platform.get_device_capability()[0]
    if major_capability != 9:
        return False, "FlashMLA is only supported on Hopper devices."

    if _flashmla_C_AVAILABLE:
        return True, None

    unavailable_reason = (
        "vllm._flashmla_C is not available, likely was not "
        "compiled due to insufficient nvcc version or a supported arch "
        "(only sm90a currently) was not in the list of target arches to "
        "compile for.")
    return False, unavailable_reason


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
    num_heads_q: Optional[int] = None,
    is_fp8_kvcache: bool = False,
    topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the tile-scheduler metadata consumed by the FlashMLA kernels.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.
        num_heads_q: Forwarded to flash_mla_cuda on ROCm only; not passed to
                     the CUDA op.
        is_fp8_kvcache: Forwarded to flash_mla_cuda on ROCm only; not passed
                        to the CUDA op.
        topk: Forwarded to flash_mla_cuda on ROCm only; not passed to the
              CUDA op.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    if not current_platform.is_rocm():
        # The CUDA op only accepts the three base arguments.
        cuda_get_metadata = torch.ops._flashmla_C.get_mla_metadata
        return cuda_get_metadata(cache_seqlens, num_heads_per_head_k,
                                 num_heads_k)

    return flash_mla_cuda.get_mla_metadata(
        cache_seqlens,
        num_heads_per_head_k,
        num_heads_k,
        num_heads_q,
        is_fp8_kvcache,
        topk,
    )
70
71


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the scheduling metadata used by flash_mla_with_kvcache_fp8.

    Delegates directly to flash_mla_cuda, which this module imports only
    when the current platform is ROCm.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    build_metadata = flash_mla_cuda.get_mla_decoding_metadata_dense_fp8
    return build_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
        

93
94
95
96
97
98
99
100
101
102
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype="auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run FlashMLA attention against a paged KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
                       Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        k_scale: Passed to the quantized ROCm kernel only; unused on every
                 other path.
        kv_cache_dtype: On ROCm, "fp8", "fp8_e4m3" or "fp8_e5m2" selects the
                        quantized kernel ("fp8" is passed on as "fp8_e4m3");
                        any other value selects the unquantized kernel.
                        Not consulted on CUDA.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    scale = q.shape[-1]**(-0.5) if softmax_scale is None else softmax_scale

    # Positional arguments shared by all three kernels. The third slot is
    # always None (presumably an unused separate V cache -- confirm against
    # the kernel signature).
    kernel_args = (
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )

    if not current_platform.is_rocm():
        attn_out, lse = torch.ops._flashmla_C.fwd_kvcache_mla(*kernel_args)
        return attn_out, lse

    if kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        # The generic "fp8" alias is passed to the kernel as "fp8_e4m3".
        if kv_cache_dtype == "fp8":
            kv_dtype = "fp8_e4m3"
        else:
            kv_dtype = kv_cache_dtype
        attn_out, lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
            *kernel_args, k_scale, kv_dtype)
        return attn_out, lse

    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla(*kernel_args)
    return attn_out, lse


172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype="auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Variant of flash_mla_with_kvcache taking the query as two tensors.

    Only a ROCm (flash_mla_cuda) code path exists for this wrapper.

    Arguments:
        q_nope: First part of the query; its last dimension contributes to
                the default softmax scale. Presumably the non-positional
                part of q -- TODO confirm layout against the kernel.
        q_pe: Second part of the query; its last dimension contributes to
              the default softmax scale. Presumably the positional-encoding
              part of q -- TODO confirm layout against the kernel.
        k_cache, block_table, cache_seqlens, head_dim_v,
        tile_scheduler_metadata, num_splits, causal: Forwarded unchanged to
            the kernel.
        softmax_scale: float. Defaults to
                       (q_nope.shape[-1] + q_pe.shape[-1]) ** -0.5.
        k_scale: Passed to the quantized kernel only.
        kv_cache_dtype: "fp8", "fp8_e4m3" or "fp8_e5m2" selects the
                        quantized kernel ("fp8" is passed on as "fp8_e4m3");
                        any other value selects the unquantized kernel.

    Return:
        out, softmax_lse: The two values returned by the kernel.

    Raises:
        NotImplementedError: If the current platform is not ROCm.
    """
    if not current_platform.is_rocm():
        # Without this guard the function fell through to
        # `return out, softmax_lse` with both names unbound and raised an
        # opaque UnboundLocalError.
        raise NotImplementedError(
            "flash_mla_with_kvcache_q_nope_pe is only implemented for ROCm.")

    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)

    if kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        # The generic "fp8" alias is passed to the kernel as "fp8_e4m3".
        kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
        out, softmax_lse = (
            flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
                q_nope,
                q_pe,
                k_cache,
                None,
                head_dim_v,
                cache_seqlens,
                block_table,
                softmax_scale,
                causal,
                tile_scheduler_metadata,
                num_splits,
                k_scale,
                kv_dtype,
            ))
        return out, softmax_lse

    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse


223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275

def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the fp8 FlashMLA kernel against a paged KV cache.

    Delegates directly to flash_mla_cuda, which this module imports only
    when the current platform is ROCm.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 torch.int32, return by
                                 get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, return by
                    get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scaling of QK^T before applying softmax.
                       Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q,
                   used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K,
                   used for fp8 quantization.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    scale = q.shape[-1]**(-0.5) if softmax_scale is None else softmax_scale

    fp8_kernel = flash_mla_cuda.fwd_kvcache_mla_fp8
    attn_out, lse = fp8_kernel(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return attn_out, lse

276
277
278
279
280
281
282
283
284
285
#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
zhuwenwen's avatar
zhuwenwen committed
286
#