# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional
5
6
7
8
9
10
11
12
13
14
15

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe for the compiled FlashMLA extension module. The import is attempted
# only on CUDA platforms and the module is imported purely for its side
# effects (presumably registering the `torch.ops._flashmla_C` operators used
# below -- TODO confirm). An ImportError is recorded in the flag rather than
# propagated, so this module can always be imported and
# `is_flashmla_supported()` can report why FlashMLA is unavailable.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

23
24
25
# Probe for the second compiled module, `vllm._flashmla_extension_C`, using
# the same pattern as above: attempt the import only on CUDA platforms and
# record the outcome in a flag instead of letting an ImportError escape.
# `torch.ops._flashmla_extension_C` is what the 1-byte-element (fp8) query
# path of `flash_mla_with_kvcache` calls into.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    _flashmla_extension_C_AVAILABLE = False

33

34
def is_flashmla_supported() -> tuple[bool, Optional[str]]:
    """
    Report whether FlashMLA can be used in the current environment.

    The checks run in order and the first failing one determines the reason:
    the platform must be CUDA, the device's major compute capability must
    be 9, and the `vllm._flashmla_C` extension must have been importable.

    Return: is_supported_flag, unsupported_reason (optional).
    The reason is None when the flag is True.
    """
    if not current_platform.is_cuda():
        return False, "FlashMLA is only supported on CUDA devices."
    # Only the major component of the capability tuple is compared.
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "(only sm90a currently) was not in the list of target arches to "
            "compile for.",
        )
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: Optional[int] = None,
    is_fp8_kvcache: bool = False,
    topk: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the tile-scheduler metadata consumed by `flash_mla_with_kvcache`.

    All arguments are forwarded unchanged, in order, to the
    `torch.ops._flashmla_C.get_mla_decoding_metadata` operator.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled.
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache` will be attended to.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    return torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
        num_heads_q,
        is_fp8_kvcache,
        topk,
    )
88
89
90
91
92
93
94
95
96
97
98
99


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FlashMLA attention kernel against a paged KV cache.

    Two kernels are dispatched between:
    - When `indices` is None and `q` has 1-byte elements, the call goes to
      `torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8`, which receives
      `descale_q` / `descale_k` (and not `is_fp8_kvcache` / `indices`).
    - Otherwise the call goes to `torch.ops._flashmla_C.fwd_kvcache_mla`,
      which receives `is_fp8_kvcache` / `indices` (and not the descale
      factors).

    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Default to 1 / sqrt(head_dim), where head_dim is `q.shape[-1]`.
    - causal: bool. Whether to apply causal attention mask.
        Must be False when `indices` is given.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
        `descale_q` and `descale_k` must be both None or both not None.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal
        # since it only attend to the tokens before
        # but here `causal` should not be specified
        assert not causal, "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (descale_k is None), (
        "descale_q and descale_k should be both None or both not None"
    )

    # A 1-byte element size on a dense (non-sparse) query selects the fp8
    # kernel from the extension module; everything else uses the main kernel.
    if indices is None and q.element_size() == 1:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            is_fp8_kvcache,
            indices,
        )
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    All arguments are forwarded unchanged, in order, to the
    `torch.ops._flashmla_C.sparse_prefill_fwd` operator, and its result is
    returned as-is.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        About the definition of output,
        max_logits and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits:  [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    return torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
208
209
210
211
212
213
214
215
216
217
218
219
220


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_decoding_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#