# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Probe the compiled FlashMLA kernels once, at import time. On CUDA the
# kernels come from vLLM's own extension module; on ROCm they come from the
# separately installed `flash_mla_cuda` package. A missing extension only
# clears the availability flag (consumed by `is_flashmla_supported`) instead
# of making this module itself fail to import.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
elif current_platform.is_rocm():
    try:
        import flash_mla_cuda
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

# The FP8-query decode kernel lives in a separate extension; it is only
# probed on CUDA.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401
        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    _flashmla_extension_C_AVAILABLE = False


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether the FlashMLA kernels can be used on this platform.

    The checks, in order: the platform must be CUDA or ROCm, the device's
    compute capability major version must be 9, and the platform's FlashMLA
    extension must have been importable (`_flashmla_C_AVAILABLE`).

    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False, "FlashMLA is supported on CUDA and ROCM devices."
    is_rocm = current_platform.is_rocm()
    if current_platform.get_device_capability()[0] != 9:
        if is_rocm:
            # "Hopper" is a CUDA-only notion; report the actual requirement.
            return False, ("FlashMLA on ROCm is only supported on devices "
                           "with compute capability major version 9.")
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        if is_rocm:
            # On ROCm the kernels come from the `flash_mla_cuda` package,
            # not from vLLM's nvcc-built extension.
            return False, ("flash_mla_cuda is not available; it must be "
                           "installed to use FlashMLA on ROCm.")
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."
    return True, None


def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the tile-scheduler metadata consumed by `flash_mla_with_kvcache`.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled.
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache` will be attended to.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.

    Raises:
    - NotImplementedError: on ROCm when `topk` is given, because the ROCm
            kernel call accepts no sparse-attention argument.
    """
    if current_platform.is_rocm():
        # The ROCm kernel only takes the dense-attention arguments, so
        # `num_heads_q` and `is_fp8_kvcache` are not forwarded.
        # NOTE(review): confirm the ROCm kernel does not need is_fp8_kvcache.
        # A sparse request (`topk`) cannot be honoured; refuse it rather than
        # silently returning dense scheduling metadata.
        if topk is not None:
            raise NotImplementedError(
                "Sparse FlashMLA (topk) is not supported on ROCm.")
        return flash_mla_cuda.get_mla_metadata(cache_seqlens,
                                               num_q_tokens_per_head_k,
                                               num_heads_k)
    return torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
        is_fp8_kvcache, topk)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FlashMLA decode kernel against a paged KV cache.

    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Default to 1 / sqrt(head_dim).
    - causal: bool. Whether to apply causal attention mask.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
    - NotImplementedError: on ROCm when `indices` is given, because the ROCm
        kernel calls accept no sparse-attention argument.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal
        # since it only attend to the tokens before
        # but here `causal` should not be specified
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    is_rocm = current_platform.is_rocm()
    if is_rocm and indices is not None:
        # The ROCm kernels have no `indices` parameter; running them would
        # silently compute dense attention instead of the requested sparse one.
        raise NotImplementedError(
            "Sparse FlashMLA (indices) is not supported on ROCm.")

    if indices is None and q.element_size() == 1:
        # 1-byte query elements: dispatch to the quantized (FP8-query) kernel.
        # TODO @yangql
        if is_rocm:
            # NOTE(review): descale_k is forwarded but descale_q is not —
            # confirm the ROCm kernel does not need descale_q.
            out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
                q, k_cache, None, head_dim_v, cache_seqlens, block_table,
                softmax_scale, causal, tile_scheduler_metadata, num_splits,
                descale_k, "fp8_e4m3")
        else:
            out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
                q, k_cache, head_dim_v, cache_seqlens, block_table,
                softmax_scale, causal, tile_scheduler_metadata, num_splits,
                descale_q, descale_k)
    elif is_rocm:
        # NOTE(review): is_fp8_kvcache is not forwarded to the ROCm kernel —
        # confirm it infers the cache format on its own.
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
            q, k_cache, None, head_dim_v, cache_seqlens, block_table,
            softmax_scale, causal, tile_scheduler_metadata, num_splits)
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
            indices)
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the sparse-attention prefill kernel (`_flashmla_C.sparse_prefill_fwd`).

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Entries that are -1 or >= s_kv are treated as invalid.
    - sm_scale: float, the softmax scale.
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse), exactly as produced by the kernel.
        See README.md for the precise definitions.
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits:  [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    sparse_prefill_op = torch.ops._flashmla_C.sparse_prefill_fwd
    return sparse_prefill_op(q, kv, indices, sm_scale, d_v)


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_decoding_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#