# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe for the FlashMLA core kernels at import time. CUDA builds use the
# in-tree ``vllm._flashmla_C`` extension; ROCm builds use the external
# ``flash_mla_cuda`` package. Every other platform is treated as unsupported.
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
    except ImportError:
        pass  # kernels not built for this target: leave the flag False
    else:
        _flashmla_C_AVAILABLE = True
elif current_platform.is_rocm():
    try:
        import flash_mla_cuda  # noqa: F401
    except ImportError:
        pass  # package not installed: leave the flag False
    else:
        _flashmla_C_AVAILABLE = True

# Probe for the FlashMLA extension kernels at import time. CUDA builds use the
# in-tree ``vllm._flashmla_extension_C`` extension; on ROCm the same external
# ``flash_mla_cuda`` package is probed again, so there both availability flags
# reflect a single import.
_flashmla_extension_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401
    except ImportError:
        pass  # extension not built: leave the flag False
    else:
        _flashmla_extension_C_AVAILABLE = True
elif current_platform.is_rocm():
    try:
        import flash_mla_cuda  # noqa: F401
    except ImportError:
        pass  # package not installed: leave the flag False
    else:
        _flashmla_extension_C_AVAILABLE = True

def _is_flashmla_available() -> tuple[bool, str | None]:
    """Report whether the FlashMLA kernels were importable at module load.

    Returns:
        ``(True, None)`` when both ``_flashmla_C_AVAILABLE`` and
        ``_flashmla_extension_C_AVAILABLE`` are set, otherwise
        ``(False, reason)`` with a human-readable explanation.
    """
    # NOTE(review): on ROCm both flags come from importing ``flash_mla_cuda``,
    # so the module names in these messages only match the CUDA build.
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    # Gate only on the import result. Also requiring
    # ``current_platform.is_rocm()`` here made FlashMLA report "unavailable"
    # on every CUDA device even when both extensions were built, turning all
    # CUDA code paths in this module into dead code.
    if not _flashmla_extension_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )

    return True, None


def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).

    Dense FlashMLA requires the kernels to be importable and the device to be
    in the compute-capability 9.x family (Hopper).
    """
    available, reason = _is_flashmla_available()
    if not available:
        return False, reason

    on_hopper = current_platform.is_device_capability_family(90)
    if on_hopper:
        return True, None
    return False, "FlashMLA Dense is only supported on Hopper devices."


def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).

    Sparse FlashMLA requires the kernels to be importable and the device to
    be in the compute-capability 9.x (Hopper) or 10.x (Blackwell) family.
    """
    available, reason = _is_flashmla_available()
    if not available:
        return False, reason

    # any() short-circuits exactly like the chained ``or`` it replaces.
    if any(
        current_platform.is_device_capability_family(family)
        for family in (90, 100)
    ):
        return True, None
    return (
        False,
        "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
    )


def _raise_flashmla_unavailable(*_args, **_kwargs):
    """Raise ``RuntimeError`` explaining why FlashMLA cannot be used.

    Any positional/keyword arguments are accepted and ignored, so this can be
    bound in place of a real kernel function.
    """
    message = _is_flashmla_available()[1]
    if not message:
        message = "FlashMLA is not available"
    raise RuntimeError(message)


# Export the FlashMLA Python API. Every path below binds the same set of
# public names, so importers of this module never hit ImportError; functions
# that cannot run on the current platform raise RuntimeError when called.
if _is_flashmla_available()[0]:
    if current_platform.is_rocm():
        # ROCm uses the external ``flash_mla`` package. This import omits the
        # varlen flash-attention entry points, so they are bound to an
        # explicit stub afterwards instead of being left undefined.
        from flash_mla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,  # requires a flash_mla build that exports it
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )

        def _raise_varlen_unsupported_on_rocm(*_args, **_kwargs):
            """Stand-in for FlashMLA varlen functions not wired up on ROCm."""
            raise RuntimeError(
                "FlashMLA varlen flash-attention functions are not "
                "supported on ROCm"
            )

        flash_attn_varlen_func = _raise_varlen_unsupported_on_rocm  # type: ignore[assignment]
        flash_attn_varlen_kvpacked_func = _raise_varlen_unsupported_on_rocm  # type: ignore[assignment]
        flash_attn_varlen_qkvpacked_func = _raise_varlen_unsupported_on_rocm  # type: ignore[assignment]
    else:
        from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            flash_attn_varlen_func,
            flash_attn_varlen_kvpacked_func,
            flash_attn_varlen_qkvpacked_func,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
else:
    # FlashMLA is unavailable: export placeholders so the names still
    # resolve; calling any function raises RuntimeError with the reason.
    class FlashMLASchedMeta:  # type: ignore[no-redef]
        pass

    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute FlashMLA decoding metadata for the dense FP8 path.

    Dispatches to ``flash_mla_cuda`` on ROCm and to the
    ``_flashmla_extension_C`` torch op everywhere else; the kernel's result
    is returned unchanged.

    Raises:
        RuntimeError: if FlashMLA is not available on this platform.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    if not current_platform.is_rocm():
        metadata_op = (
            torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8
        )
        return metadata_op(cache_seqlens, num_q_tokens_per_head_k, num_heads_k)

    # The ROCm entry point takes one extra trailing argument, fixed at 16 by
    # this port. NOTE(review): what 16 denotes (block/tile size?) is not
    # visible here -- confirm against the flash_mla_cuda signature.
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, 16
    )


def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call the platform's FlashMLA ``fwd_kvcache_mla_fp8`` kernel.

    Args:
        softmax_scale: Softmax scale; when None, defaults to
            ``q.shape[-1] ** -0.5``.
        descale_q, descale_k: Optional tensors forwarded unchanged to the
            kernel (presumably FP8 descaling factors -- confirm).
        Remaining arguments are forwarded positionally to the kernel.

    Returns:
        ``(out, softmax_lse)`` unpacked from the kernel's result.

    Raises:
        RuntimeError: if FlashMLA is not available on this platform.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale

    # Argument tail shared by both backends, in kernel order after k_cache.
    shared_args = (
        head_dim_v,
        cache_seqlens,
        block_table,
        scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )

    if current_platform.is_rocm():
        # The ROCm kernel has one extra positional slot right after k_cache;
        # this port always passes None there.
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
            q, k_cache, None, *shared_args
        )
    else:
        kernel = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8
        out, softmax_lse = kernel(q, k_cache, *shared_args)

    return out, softmax_lse


#
# TODO: Register fake (meta) implementations for the custom ops, e.g.:
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#