# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
9
from vllm.platforms.rocm import get_gcn_arch_name
10
11
12
13
14
logger = init_logger(__name__)

# Availability flags for the compiled FlashMLA kernels. Both start out False
# and are only flipped once the matching import has actually succeeded, so a
# missing or broken build degrades to "FlashMLA unavailable" instead of
# failing the import of this module.
_flashmla_C_AVAILABLE = False
_flashmla_extension_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False

    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False

if current_platform.is_rocm():
    # On ROCm the kernels come from the external ``flash_mla`` package rather
    # than from vLLM's own extension modules. Guard the import the same way
    # the CUDA path does so that an absent package does not make this module
    # unimportable.
    try:
        from flash_mla.flash_mla_interface import flash_mla_cuda
    except ImportError as exc:
        # Keep the name bound so later references fail with a clear error
        # rather than a NameError.
        flash_mla_cuda = None  # type: ignore[assignment]
        logger.warning(
            "Failed to import flash_mla on ROCm; FlashMLA will be "
            "reported as unavailable: %s",
            exc,
        )
    else:
        _flashmla_C_AVAILABLE = True
        _flashmla_extension_C_AVAILABLE = True
38

39

40
def _is_flashmla_available() -> tuple[bool, str | None]:
    """Report whether the compiled FlashMLA kernels were loaded.

    Returns:
        ``(True, None)`` when both module-level availability flags are set;
        otherwise ``(False, reason)`` where ``reason`` names the first
        missing extension.
    """
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    # The platform is intentionally not consulted here. The module-level ROCm
    # branch sets both flags to True itself, so checking the flag alone is
    # enough for ROCm; an extra ``not is_rocm()`` clause would make every
    # other platform report "unavailable" even when both extensions were
    # imported successfully.
    if not _flashmla_extension_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )

    return True, None


58
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """Check whether the dense FlashMLA path can be used.

    Returns:
        ``(is_supported, unsupported_reason)``; the reason is ``None`` when
        the path is supported.
    """
    available, reason = _is_flashmla_available()
    if not available:
        # Propagate the reason reported by the availability check unchanged.
        return False, reason

    if current_platform.is_device_capability_family(90):
        return True, None
    return False, "FlashMLA Dense is only supported on Hopper devices."


70
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """Check whether the sparse FlashMLA path can be used.

    Returns:
        ``(is_supported, unsupported_reason)``; the reason is ``None`` when
        the path is supported.
    """
    available, reason = _is_flashmla_available()
    if not available:
        # Propagate the reason reported by the availability check unchanged.
        return False, reason

    # Supported when the device belongs to capability family 90 or 100.
    if not current_platform.is_device_capability_family(
        90
    ) and not current_platform.is_device_capability_family(100):
        return (
            False,
            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
        )
    return True, None


88
89
90
91
92
93
def _raise_flashmla_unavailable(*_args, **_kwargs):
    """Raise ``RuntimeError`` explaining why FlashMLA cannot be used.

    Accepts and ignores any arguments so it can stand in for any of the
    FlashMLA entry points. The message is the reason reported by
    ``_is_flashmla_available``, or a generic one when no reason is given.
    """
    reason = _is_flashmla_available()[1]
    message = reason if reason else "FlashMLA is not available"
    raise RuntimeError(message)


# Bind the public FlashMLA entry points. When the kernels are available the
# names are re-exported from the platform-specific interface module; otherwise
# each callable is replaced by a stub that raises on use, so importing this
# module never fails just because FlashMLA is missing.
if _is_flashmla_available()[0]:
    if current_platform.is_rocm():
        # NOTE(review): the three flash_attn_varlen_* helpers are not imported
        # from the ROCm package, so those names stay unbound in this module on
        # ROCm -- confirm no caller imports them on that platform.
        from flash_mla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
    else:
        from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            flash_attn_varlen_func,
            flash_attn_varlen_kvpacked_func,
            flash_attn_varlen_qkvpacked_func,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
else:

    # Empty placeholder so the name can still be imported and referenced.
    class FlashMLASchedMeta:  # type: ignore[no-redef]
        pass

    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward to the platform's ``get_mla_decoding_metadata_dense_fp8`` kernel.

    Args:
        cache_seqlens: Passed through unchanged as the first kernel argument.
        num_q_tokens_per_head_k: Passed through unchanged.
        num_heads_k: Passed through unchanged.

    Returns:
        Whatever the selected kernel returns (annotated as a pair of tensors).

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    # ROCm uses the external flash_mla binding; every other platform goes
    # through the registered torch custom op.
    if current_platform.is_rocm():
        metadata_fn = flash_mla_cuda.get_mla_decoding_metadata_dense_fp8
    else:
        metadata_fn = (
            torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8
        )
    return metadata_fn(cache_seqlens, num_q_tokens_per_head_k, num_heads_k)
147
148


149
def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
    kv_cache_dtype: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the FP8 FlashMLA KV-cache forward kernel for the current platform.

    Args:
        q, k_cache, block_table, cache_seqlens, head_dim_v,
        tile_scheduler_metadata, num_splits, causal: Forwarded to the kernel.
        softmax_scale: Scale forwarded to the kernel; when ``None`` it
            defaults to ``q.shape[-1] ** -0.5``.
        descale_q: Forwarded on the non-ROCm path and on ROCm ``gfx938``;
            not used by the other ROCm kernel.
        descale_k: Forwarded on every path.
        kv_cache_dtype: Only forwarded to the ROCm kernel used for
            architectures other than ``gfx938``.

    Returns:
        The ``(out, softmax_lse)`` pair produced by the kernel.

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale

    if not current_platform.is_rocm():
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
        return out, softmax_lse

    # Leading positional arguments shared by both ROCm kernels. The third
    # slot is always passed as None (presumably an optional V cache -- TODO
    # confirm against the flash_mla binding).
    rocm_common_args = (
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    if get_gcn_arch_name() == "gfx938":
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
            *rocm_common_args,
            descale_q,
            descale_k,
        )
    else:
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
            *rocm_common_args,
            descale_k,
            kv_cache_dtype,
        )
    return out, softmax_lse
214
215
216
217
218
219
220
221
222
223
224
225


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#