# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe for the compiled FlashMLA CUDA extension. A failed import is not
# fatal: it only marks the extension as unavailable so that
# _is_flashmla_available() can report a reason to callers.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

# Probe for the companion FlashMLA extension module, which provides the
# torch.ops._flashmla_extension_C ops used by the FP8 helpers in this file.
# As above, an ImportError only marks it unavailable.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    _flashmla_extension_C_AVAILABLE = False
# On ROCm the FlashMLA kernels are taken from the ``flash_mla_cuda`` module
# rather than vllm._flashmla_C. Guard the import like the CUDA probes so a
# missing module marks FlashMLA unavailable instead of breaking the import of
# this whole file.
# NOTE(review): _flashmla_extension_C_AVAILABLE is only set True by the
# is_cuda() probe, so (presumably is_cuda() is False on ROCm) it stays False
# here and _is_flashmla_available() still returns False. That would make the
# is_rocm() branches of the FP8 helpers unreachable -- confirm intended.
if current_platform.is_rocm():
    try:
        import flash_mla_cuda

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False

def _is_flashmla_available() -> tuple[bool, str | None]:
    """Check whether the FlashMLA extension modules were importable.

    Reads the module-level availability flags that are set once, at import
    time, by the import probes in this file.

    Returns:
        ``(True, None)`` when both flags are set; otherwise
        ``(False, reason)``, where ``reason`` describes the first missing
        extension.
    """
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    if not _flashmla_extension_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )

    return True, None


def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """Report whether the dense FlashMLA kernels can be used on this device.

    Returns:
        ``(is_supported, unsupported_reason)``. The result is ``(True, None)``
        only when the FlashMLA extensions are available and the device is in
        compute-capability family 90 (Hopper). Otherwise it is
        ``(False, reason)``.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    if not current_platform.is_device_capability_family(90):
        return False, "FlashMLA Dense is only supported on Hopper devices."
    return True, None


def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """Report whether the sparse FlashMLA kernels can be used on this device.

    Returns:
        ``(is_supported, unsupported_reason)``. The result is ``(True, None)``
        only when the FlashMLA extensions are available and the device is in
        compute-capability family 90 (Hopper) or 100 (Blackwell). Otherwise it
        is ``(False, reason)``.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    if not (
        current_platform.is_device_capability_family(90)
        or current_platform.is_device_capability_family(100)
    ):
        return (
            False,
            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
        )
    return True, None


def _raise_flashmla_unavailable(*_args, **_kwargs):
    """Placeholder for FlashMLA entry points when the kernels are unavailable.

    Accepts and ignores any positional/keyword arguments, so it can stand in
    for any kernel function, and always raises.

    Raises:
        RuntimeError: carrying the reason from ``_is_flashmla_available()``,
            or a generic message when no reason is reported.
    """
    reason = _is_flashmla_available()[1]
    if not reason:
        reason = "FlashMLA is not available"
    raise RuntimeError(reason)


# Bind the public FlashMLA entry points. When the FlashMLA extensions are
# available, re-export the implementations from the third_party FlashMLA
# interface. Otherwise bind placeholders so the names still exist; calling any
# of them raises RuntimeError via _raise_flashmla_unavailable().
# NOTE: keep the real import in the first branch. The placeholders below carry
# the ``type: ignore`` comments that silence mypy's redefinition errors.
if _is_flashmla_available()[0]:
    from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
        FlashMLASchedMeta,
        flash_attn_varlen_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_mla_sparse_fwd,
        flash_mla_with_kvcache,
        get_mla_metadata,
    )
else:

    class FlashMLASchedMeta:  # type: ignore[no-redef]
        # Empty placeholder; presumably exists only so references to the
        # name keep resolving when FlashMLA is unavailable.
        pass

    # Each kernel entry point raises with the unavailability reason when called.
    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: int = 16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute decoding metadata for the dense FP8 FlashMLA path.

    Dispatches to ``flash_mla_cuda`` on ROCm and to
    ``torch.ops._flashmla_extension_C`` otherwise.

    Args:
        cache_seqlens: KV-cache sequence lengths (presumably one entry per
            request -- confirm against callers).
        num_q_tokens_per_head_k: Forwarded unchanged to the backend op.
        num_heads_k: Forwarded unchanged to the backend op.
        num_heads_q: Number of query heads. It is only passed to the ROCm
            kernel; the CUDA op is called without it, so it is ignored there.

    Returns:
        The pair of tensors produced by the backend op. These are presumably
        the ``tile_scheduler_metadata`` and ``num_splits`` arguments of
        ``flash_mla_with_kvcache_fp8`` -- confirm.

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    if current_platform.is_rocm():
        # NOTE(review): this goes through ``flash_mla_cuda.flash_mla_cuda``,
        # while flash_mla_with_kvcache_fp8 calls
        # ``flash_mla_cuda.fwd_kvcache_mla_fp8`` directly. Confirm which
        # attribute path the module actually exposes.
        return flash_mla_cuda.flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
            num_heads_q,
        )
    return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
    )

def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the FP8 FlashMLA attention kernel against a block-table KV cache.

    Dispatches to ``flash_mla_cuda`` on ROCm and to
    ``torch.ops._flashmla_extension_C`` otherwise.

    Args:
        q: Query tensor. Its last dimension sets the default softmax scale.
        k_cache: KV cache, addressed through ``block_table``.
        block_table: Forwarded unchanged to the backend op.
        cache_seqlens: Forwarded unchanged to the backend op.
        head_dim_v: Forwarded unchanged to the backend op.
        tile_scheduler_metadata: Scheduler metadata (presumably from
            ``get_mla_metadata_dense_fp8`` -- confirm).
        num_splits: Split counts (presumably from
            ``get_mla_metadata_dense_fp8`` -- confirm).
        softmax_scale: Attention scale; defaults to ``q.shape[-1] ** -0.5``.
        causal: Forwarded to the kernel.
        descale_q: Optional descale tensor for ``q``, forwarded to the kernel.
        descale_k: Optional descale tensor for ``k_cache``, forwarded to the
            kernel.

    Returns:
        ``(out, softmax_lse)`` as returned by the backend op.

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()
    if softmax_scale is None:
        # Standard 1/sqrt(head_dim) scaling based on the query head size.
        softmax_scale = q.shape[-1] ** (-0.5)

    if current_platform.is_rocm():
        # The ROCm op takes one extra positional argument after k_cache,
        # passed as None here (presumably an optional separate V cache --
        # confirm against the flash_mla_cuda signature).
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    else:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    return out, softmax_lse


#
# TODO: Add fake implementations of the custom ops, e.g.:
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#