"docs/backends/vllm/deepseek-r1.md" did not exist on "03360b84756931e13113656711817094fb97799e"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe for the compiled FlashMLA kernels. The import is only attempted on
# CUDA platforms, and a failed import is recorded in the flag rather than
# raised.
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
    except ImportError:
        # Leave the flag at False; _is_flashmla_available() reports the reason.
        pass
    else:
        _flashmla_C_AVAILABLE = True

22
23
24
# Probe for the FlashMLA extension kernels (the ops looked up under
# torch.ops._flashmla_extension_C elsewhere in this module).
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    # Non-CUDA platforms never import the extension; the flag is set to True
    # unconditionally and is consulted by _is_flashmla_available().
    # NOTE(review): presumably because the ROCm code paths call into
    # flash_mla_cuda instead of this extension -- confirm.
    _flashmla_extension_C_AVAILABLE = True
32
33
34
35
    
if current_platform.is_rocm():
    # On ROCm the kernels are provided by the flash_mla_cuda module. Guard the
    # import the same way as the CUDA probes so that a missing package marks
    # FlashMLA as unavailable instead of breaking the import of this module.
    try:
        import flash_mla_cuda

        _flashmla_C_AVAILABLE = True
    except ImportError:
        logger.warning(
            "flash_mla_cuda could not be imported; FlashMLA is unavailable on ROCm."
        )
        _flashmla_C_AVAILABLE = False
36

37

38
def _is_flashmla_available() -> tuple[bool, str | None]:
    """
    Report whether the FlashMLA kernels were loaded at import time.

    Return: is_available_flag, unavailable_reason (optional).
    """
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    # The extension is only required off ROCm; the ROCm branches in this
    # module call flash_mla_cuda instead. This must be `and`: with `or`, every
    # non-ROCm platform would be rejected even when the extension loaded.
    if not _flashmla_extension_C_AVAILABLE and not current_platform.is_rocm():
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )

    return True, None


56
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """
    Check whether the dense FlashMLA kernels can be used on this device.

    Return: is_supported_flag, unsupported_reason (optional).
    """
    available, reason = _is_flashmla_available()
    if not available:
        return False, reason

    # Only compute-capability major version 9 is accepted for the dense path.
    capability_major = current_platform.get_device_capability()[0]
    if capability_major == 9:
        return True, None
    return False, "FlashMLA Dense is only supported on Hopper devices."


68
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """
    Check whether the sparse FlashMLA kernels can be used on this device.

    Return: is_supported_flag, unsupported_reason (optional).
    """
    available, reason = _is_flashmla_available()
    if not available:
        return False, reason

    # Compute-capability major versions 9 and 10 are accepted for sparse.
    capability_major = current_platform.get_device_capability()[0]
    if capability_major in (9, 10):
        return True, None
    return (
        False,
        "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
    )


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: int | None = None,
    is_fp8_kvcache: bool = False,
    topk: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the tile-scheduler metadata consumed by `flash_mla_with_kvcache`.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            Optional when sparse attention is not enabled.
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention is enabled and only the tokens
            listed in the `indices` array given to `flash_mla_with_kvcache`
            are attended to.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    # Dense attention over an FP8 KV cache uses the dedicated extension op,
    # which only takes the first three arguments.
    # NOTE(review): this branch is taken before the ROCm check below, so it is
    # also reached on ROCm -- confirm the extension op exists there.
    if is_fp8_kvcache and topk is None:
        return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
        )

    # Both remaining backends accept the same argument list; only the entry
    # point differs between ROCm and the compiled CUDA op.
    if current_platform.is_rocm():
        metadata_fn = flash_mla_cuda.get_mla_metadata
    else:
        metadata_fn = torch.ops._flashmla_C.get_mla_decoding_metadata
    return metadata_fn(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
        num_heads_q,
        is_fp8_kvcache,
        topk,
    )
134
135
136
137
138
139
140
141
142
143


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
    is_fp8_kvcache: bool = False,
    indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FlashMLA decoding kernel against a paged KV cache.

    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        as returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, as returned by get_mla_metadata.
    - softmax_scale: float.
        Scale applied to QK^T before the softmax.
        Defaults to 1 / sqrt(head_dim).
    - causal: bool. Whether to apply a causal attention mask.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of the FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention is enabled and only the tokens in the
        `indices` array are attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    sparse_enabled = indices is not None
    # NOTE (zyongye): sparse attention is also causal
    # since it only attend to the tokens before
    # but here `causal` should not be specified
    assert not (sparse_enabled and causal), (
        "causal must be `false` if sparse attention is enabled."
    )
    assert (descale_q is None) is (descale_k is None), (
        "descale_q and descale_k should be both None or both not None"
    )

    # A dense call whose query has 1-byte elements goes to the FP8 kernels;
    # everything else goes to the regular kernels.
    fp8_dense_query = not sparse_enabled and q.element_size() == 1
    on_rocm = current_platform.is_rocm()

    if fp8_dense_query and on_rocm:
        # TODO @yangql
        # descale_q is not forwarded on this path.
        attn_out, lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_k,
            "fp8_e4m3",
        )
    elif fp8_dense_query:
        attn_out, lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    elif on_rocm:
        # The ROCm entry point takes its arguments in a different order than
        # the compiled CUDA op below.
        attn_out, lse = flash_mla_cuda.fwd_kvcache_mla(
            q,
            k_cache,
            block_table,
            cache_seqlens,
            head_dim_v,
            tile_scheduler_metadata,
            num_splits,
            softmax_scale,
            causal,
            is_fp8_kvcache,
            indices,
        )
    else:
        attn_out, lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            is_fp8_kvcache,
            indices,
        )
    return attn_out, lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the sparse attention prefill kernel.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        For the definition of output, max_logits and lse,
        please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits:  [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    # Both backends share one call signature; pick the entry point by platform.
    if current_platform.is_rocm():
        prefill_fwd = flash_mla_cuda.sparse_prefill_fwd
    else:
        prefill_fwd = torch.ops._flashmla_C.sparse_prefill_fwd
    return prefill_fwd(q, kv, indices, sm_scale, d_v)
288
289
290
291
292
293
294
295
296
297
298
299


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_decoding_metadata")
# def _get_mla_decoding_metadata_fake(....) -> tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> tuple[torch.Tensor, torch.Tensor]:
#     return ....
#