# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe the FlashMLA kernel modules once at import time. On CUDA these are
# the vLLM-built extensions; on ROCm the external `flash_mla` package is used.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    _flashmla_extension_C_AVAILABLE = False

if current_platform.is_rocm():
    # On ROCm a single `flash_mla` import stands in for both vLLM-built
    # extensions, so it drives both availability flags. Guard it like the
    # CUDA probes so a missing package disables FlashMLA instead of making
    # this whole module fail to import.
    try:
        from flash_mla.flash_mla_interface import flash_mla_cuda

        _flashmla_C_AVAILABLE = True
        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
        _flashmla_extension_C_AVAILABLE = False


def _is_flashmla_available() -> tuple[bool, str | None]:
41
42
43
44
45
46
47
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
zhuwenwen's avatar
zhuwenwen committed
48
    if not _flashmla_extension_C_AVAILABLE or not current_platform.is_rocm():
49
50
51
52
53
54
55
56
57
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )

    return True, None


58
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
59
60
61
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
62
63
64
    is_availble, maybe_reason = _is_flashmla_available()
    if not is_availble:
        return False, maybe_reason
65
    if not current_platform.is_device_capability_family(90):
66
67
68
69
        return False, "FlashMLA Dense is only supported on Hopper devices."
    return True, None


70
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
71
72
73
74
75
76
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    is_availble, maybe_reason = _is_flashmla_available()
    if not is_availble:
        return False, maybe_reason
77
78
79
80
    if not (
        current_platform.is_device_capability_family(90)
        or current_platform.is_device_capability_family(100)
    ):
81
82
        return (
            False,
83
            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
84
        )
85
86
87
    return True, None


88
89
90
91
92
93
def _raise_flashmla_unavailable(*_args, **_kwargs):
    """Stand-in for FlashMLA entry points when the kernels cannot be used.

    Accepts and ignores any arguments so it can be bound in place of a real
    kernel function, and always raises ``RuntimeError`` carrying the reason
    from ``_is_flashmla_available`` (or a generic message if there is none).
    """
    unavailable_reason = _is_flashmla_available()[1]
    message = unavailable_reason if unavailable_reason else "FlashMLA is not available"
    raise RuntimeError(message)


# Bind the module's public FlashMLA entry points: real kernels when available,
# otherwise stubs that raise a descriptive RuntimeError when called.
if _is_flashmla_available()[0]:
    if current_platform.is_rocm():
        # This port does not import the varlen helpers from the ROCm
        # `flash_mla` package; only the entry points below are taken from it.
        from flash_mla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )

        def _raise_varlen_unsupported_on_rocm(*_args, **_kwargs):
            """Stub for varlen entry points that are not wired up on ROCm."""
            raise RuntimeError(
                "FlashMLA varlen attention functions are not available on ROCm."
            )

        # Keep these public names defined on ROCm so importers get a clear
        # error at call time instead of an ImportError.
        flash_attn_varlen_func = _raise_varlen_unsupported_on_rocm  # type: ignore[assignment]
        flash_attn_varlen_kvpacked_func = _raise_varlen_unsupported_on_rocm  # type: ignore[assignment]
        flash_attn_varlen_qkvpacked_func = _raise_varlen_unsupported_on_rocm  # type: ignore[assignment]
    else:
        from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            flash_attn_varlen_func,
            flash_attn_varlen_kvpacked_func,
            flash_attn_varlen_qkvpacked_func,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
else:

    class FlashMLASchedMeta:  # type: ignore[no-redef]
        pass

    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fetch decoding metadata for the dense FP8 FlashMLA kernel.

    Args:
        cache_seqlens: KV-cache sequence lengths (presumably one entry per
            request — TODO confirm).
        num_q_tokens_per_head_k: Query tokens per KV head, forwarded as-is.
        num_heads_k: Number of KV heads, forwarded as-is.

    Returns:
        The pair produced by the backend's
        ``get_mla_decoding_metadata_dense_fp8`` op (presumably the
        ``tile_scheduler_metadata`` / ``num_splits`` later passed to
        ``flash_mla_with_kvcache_fp8`` — TODO confirm).

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    if current_platform.is_rocm():
        # NOTE(review): the ROCm kernel takes an extra trailing argument that
        # is hard-coded to 16 here; its meaning (block/tile size?) is not
        # visible from this file — confirm against the flash_mla interface.
        return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
            16,
        )

    return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
    )


def flash_mla_with_kvcache_fp8(
150
151
152
153
154
155
156
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
157
    softmax_scale: float | None = None,
158
    causal: bool = False,
159
160
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
161
) -> tuple[torch.Tensor, torch.Tensor]:
162
163
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()
164
    if softmax_scale is None:
165
        softmax_scale = q.shape[-1] ** (-0.5)
zhuwenwen's avatar
zhuwenwen committed
166
    if current_platform.is_rocm():
167
168
169
170
171
172
173
174
175
176
177
178
179
180
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
zhuwenwen's avatar
zhuwenwen committed
181
    else:
182
183
184
185
186
187
188
189
190
191
192
193
194
195
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )

196
    return out, softmax_lse
197
198
199
200
201
202
203
204
205
206
207
208


#
# TODO: Register fake (meta) implementations for the custom ops, e.g.:
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#