# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Availability flag for the compiled FlashMLA extension. It starts out False
# and is flipped only when we are on CUDA and the extension module imports.
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
    except ImportError:
        # Extension was not built for this install; the flag stays False.
        pass
    else:
        _flashmla_C_AVAILABLE = True

22
23
24
25
26
27
28
29
# Availability flag for the FlashMLA extension module (the one providing the
# fp8 forward op used below). False unless we are on CUDA and it imports.
_flashmla_extension_C_AVAILABLE = False
if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401
    except ImportError:
        # Extension was not built for this install; the flag stays False.
        pass
    else:
        _flashmla_extension_C_AVAILABLE = True
zhuwenwen's avatar
zhuwenwen committed
30
31
32
33
    
if current_platform.is_rocm():
    # Mirror the CUDA guards above: a missing backend package must not make
    # this whole module unimportable. `is_flashmla_supported()` reports the
    # backend as unavailable through `_flashmla_C_AVAILABLE` instead.
    try:
        import flash_mla_cuda
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
34

35
36
37
38
39

def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether the FlashMLA kernels can be used on the current platform.

    The checks run in order: platform kind (CUDA or ROCm), device capability
    major version (must be 9), and whether the compiled backend was imported
    at module load time.

    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False, "FlashMLA is only supported on CUDA and ROCM devices."

    # Guard against a missing capability (None) so the subscript below cannot
    # raise; an unknown capability is reported as an unsupported device.
    capability = current_platform.get_device_capability()
    if capability is None or capability[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."

    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."
    return True, None


def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the tile-scheduler metadata consumed by `flash_mla_with_kvcache`.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache` will be attended to.

    On ROCm only `cache_seqlens`, `num_q_tokens_per_head_k` and
    `num_heads_k` are forwarded to the backend; `num_heads_q`,
    `is_fp8_kvcache` and `topk` are not passed on.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    if current_platform.is_rocm():
        # NOTE(review): the ROCm entry point is called with three arguments
        # only, so a sparse/fp8 request is not visible to it - confirm that
        # callers never rely on `topk` / `is_fp8_kvcache` on this platform.
        return flash_mla_cuda.get_mla_metadata(cache_seqlens,
                                               num_q_tokens_per_head_k,
                                               num_heads_k)
    return torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
        is_fp8_kvcache, topk)
87
88


zhuwenwen's avatar
zhuwenwen committed
89

90
91
92
93
94
95
96
97
98
99
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FlashMLA forward pass against a paged KV cache.

    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Default to 1 / sqrt(head_dim).
    - causal: bool. Whether to apply causal attention mask.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Dispatch:
    - `indices is None` and `q` has a 1-byte element size: the fp8 op of
      `_flashmla_extension_C` is used, with `descale_q` / `descale_k`.
    - otherwise on ROCm: `flash_mla_cuda.fwd_kvcache_mla` is used;
      `is_fp8_kvcache` and `indices` are not forwarded to it.
    - otherwise: `torch.ops._flashmla_C.fwd_kvcache_mla` is used.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)
    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal
        # since it only attend to the tokens before
        # but here `causal` should not be specified
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    if indices is None and q.element_size() == 1:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
    elif current_platform.is_rocm():
        # NOTE(review): `indices` and `is_fp8_kvcache` are dropped on this
        # branch - confirm the ROCm backend is never asked for sparse or
        # fp8-cache attention.
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
            q, k_cache, block_table, cache_seqlens, head_dim_v,
            tile_scheduler_metadata, num_splits, softmax_scale, causal)
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
            indices)
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Thin wrapper around the `_flashmla_C.sparse_prefill_fwd` op
    (sparse attention prefill kernel).

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse), exactly as produced by the op.
        About the definition of output,
        max_logits and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits:  [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    # All arguments are forwarded unchanged; the op result is returned as-is.
    sparse_prefill_op = torch.ops._flashmla_C.sparse_prefill_fwd
    return sparse_prefill_op(q, kv, indices, sm_scale, d_v)
196
197
198
199
200
201
202
203
204
205
206
207
208


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_decoding_metadata")
# def _get_mla_decoding_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#