# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe for the compiled vllm._flashmla_C extension. The imported name itself
# is not used here (hence the noqa); only whether the import succeeds matters.
# The flag stays False on non-CUDA platforms and when the import fails.
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
    except ImportError:
        # Extension was not built/installed; _is_flashmla_available() turns
        # the False flag into a human-readable reason.
        pass
    else:
        _flashmla_C_AVAILABLE = True

22
23
24
# Probe for the compiled vllm._flashmla_extension_C extension, whose ops are
# reached below through torch.ops._flashmla_extension_C. The flag stays False
# on non-CUDA platforms and when the import fails.
_flashmla_extension_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401
    except ImportError:
        # Extension was not built/installed; _is_flashmla_available() turns
        # the False flag into a human-readable reason.
        pass
    else:
        _flashmla_extension_C_AVAILABLE = True
zhuwenwen's avatar
zhuwenwen committed
31
32
33
34
    
if current_platform.is_rocm():
    # Guard the import like the CUDA extension probes in this module: a ROCm
    # host without the flash_mla_cuda package should still be able to import
    # this file, with the fp8 entry points reporting FlashMLA as unavailable
    # instead of the whole module failing at import time.
    try:
        import flash_mla_cuda
    except ImportError:
        # The ROCm code paths below need flash_mla_cuda, so mark the backend
        # as unavailable rather than leaving the name unbound behind a True
        # flag.
        _flashmla_C_AVAILABLE = False
    else:
        _flashmla_C_AVAILABLE = True
    # NOTE(review): only _flashmla_C_AVAILABLE is set here.
    # _flashmla_extension_C_AVAILABLE keeps the value assigned by the CUDA
    # probe, so _is_flashmla_available() still returns False on ROCm unless
    # that probe also succeeded, which would make the ROCm branches in the
    # fp8 wrappers unreachable - confirm whether that is intended.
35

36

37
def _is_flashmla_available() -> tuple[bool, str | None]:
    """Report whether the FlashMLA compiled extensions were importable.

    Reads the two module-level availability flags set at import time.

    Returns:
        ``(True, None)`` when both flags are set; otherwise ``(False, reason)``
        where ``reason`` describes the first extension found missing.
    """
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    if not _flashmla_extension_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )

    return True, None


55
def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).

    Dense FlashMLA is reported as supported only when the compiled
    extensions are available and the current platform is in device
    capability family 90.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        # Propagate the extension-level reason unchanged.
        return False, maybe_reason
    if not current_platform.is_device_capability_family(90):
        return False, "FlashMLA Dense is only supported on Hopper devices."
    return True, None


67
def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """
    Return: is_supported_flag, unsupported_reason (optional).

    Sparse FlashMLA is reported as supported only when the compiled
    extensions are available and the current platform is in device
    capability family 90 or 100.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        # Propagate the extension-level reason unchanged.
        return False, maybe_reason
    if not (
        current_platform.is_device_capability_family(90)
        or current_platform.is_device_capability_family(100)
    ):
        return (
            False,
            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
        )
    return True, None


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def _raise_flashmla_unavailable(*_args, **_kwargs):
    """Always raise ``RuntimeError`` explaining why FlashMLA cannot be used.

    Accepts and ignores any arguments so it can stand in for the real
    FlashMLA entry points. The message is the reason reported by
    ``_is_flashmla_available()``, or a generic fallback when that reason is
    empty.
    """
    why = _is_flashmla_available()[1]
    message = why if why else "FlashMLA is not available"
    raise RuntimeError(message)


# Re-export the FlashMLA Python interface when the compiled extensions are
# available. Otherwise bind the same names to placeholders so that importing
# them from this module still succeeds; the failure is deferred to call time.
if _is_flashmla_available()[0]:
    from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
        FlashMLASchedMeta,
        flash_attn_varlen_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_mla_sparse_fwd,
        flash_mla_with_kvcache,
        get_mla_metadata,
    )
else:

    # Empty stand-in so the class name exists without the extensions.
    class FlashMLASchedMeta:  # type: ignore[no-redef]
        pass

    # Function placeholders: calling any of them raises RuntimeError carrying
    # the reason reported by _is_flashmla_available().
    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: int = 16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fetch dense-FP8 MLA decoding metadata from the platform backend.

    Thin dispatcher: the arguments are forwarded unchanged to the backend's
    ``get_mla_decoding_metadata_dense_fp8`` op and its result is returned
    as-is.

    Args:
        cache_seqlens: Forwarded to the backend op.
        num_q_tokens_per_head_k: Forwarded to the backend op.
        num_heads_k: Forwarded to the backend op.
        num_heads_q: Forwarded only on ROCm; the non-ROCm op is called
            without it.

    Returns:
        Whatever the backend op returns (annotated as a pair of tensors).

    Raises:
        RuntimeError: If ``_is_flashmla_available()`` reports FlashMLA as
            unavailable.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    if current_platform.is_rocm():
        # NOTE(review): this goes through `flash_mla_cuda.flash_mla_cuda`,
        # while flash_mla_with_kvcache_fp8 calls `flash_mla_cuda.<op>`
        # directly - confirm which attribute path the package exposes.
        return flash_mla_cuda.flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
            num_heads_q,
        )
    else:
        return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
        )
135

zhuwenwen's avatar
zhuwenwen committed
136

137
def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the FP8 ``fwd_kvcache_mla_fp8`` op of the platform backend.

    Thin dispatcher: apart from defaulting ``softmax_scale``, the arguments
    are forwarded unchanged to the backend op.

    Args:
        q: Query tensor; its last dimension sets the default softmax scale.
        k_cache: Forwarded to the backend op.
        block_table: Forwarded to the backend op.
        cache_seqlens: Forwarded to the backend op.
        head_dim_v: Forwarded to the backend op.
        tile_scheduler_metadata: Forwarded to the backend op.
        num_splits: Forwarded to the backend op.
        softmax_scale: Scale passed to the op. When ``None`` it defaults to
            ``q.shape[-1] ** -0.5``.
        causal: Forwarded to the backend op.
        descale_q: Forwarded to the backend op.
        descale_k: Forwarded to the backend op.

    Returns:
        The ``(out, softmax_lse)`` pair produced by the backend op.

    Raises:
        RuntimeError: If ``_is_flashmla_available()`` reports FlashMLA as
            unavailable.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if current_platform.is_rocm():
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            # The ROCm op takes one extra positional argument here; it is
            # always passed as None. NOTE(review): its meaning is not visible
            # in this file - confirm against the flash_mla_cuda signature.
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    else:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    return out, softmax_lse


186
187
188
189
190
191
192
193
194
195
196
#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#