"vllm/vscode:/vscode.git/clone" did not exist on "2f308214c0ff6cfa849879c5beb884192714f429"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Availability flags for the two compiled FlashMLA modules. Both start out
# False and are only flipped once the corresponding import has succeeded, so
# non-CUDA platforms never attempt the imports at all.
# NOTE(review): the modules are imported but never referenced by name
# (hence `noqa: F401`); presumably the import is what makes the
# `torch.ops._flashmla_C` / `torch.ops._flashmla_extension_C` namespaces
# used below available -- confirm against the build.
_flashmla_C_AVAILABLE = False
_flashmla_extension_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
    except ImportError:
        # Module was not built/installed; leave the flag at False.
        pass
    else:
        _flashmla_C_AVAILABLE = True

    try:
        import vllm._flashmla_extension_C  # noqa: F401
    except ImportError:
        # Module was not built/installed; leave the flag at False.
        pass
    else:
        _flashmla_extension_C_AVAILABLE = True

def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Report whether FlashMLA can be used in the current process.

    The checks run in order and stop at the first failure: the platform must
    be CUDA, the first component of the device capability must be 9, and
    the `vllm._flashmla_C` module must have been importable at module load.

    Return: is_supported_flag, unsupported_reason (optional).
    `unsupported_reason` is None exactly when `is_supported_flag` is True.
    """
    reason: Optional[str] = None
    if not current_platform.is_cuda():
        reason = "FlashMLA is only supported on CUDA devices."
    elif current_platform.get_device_capability()[0] != 9:
        reason = "FlashMLA is only supported on Hopper devices."
    elif not _flashmla_C_AVAILABLE:
        reason = (
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "(only sm90a currently) was not in the list of target arches to "
            "compile for.")
    return reason is None, reason


def get_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the scheduling metadata consumed by `flash_mla_with_kvcache`.

    This is a thin wrapper: every argument is forwarded positionally, in
    the order declared here, to the compiled op
    `torch.ops._flashmla_C.get_mla_decoding_metadata`.

    Arguments:
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled.
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache` will be attended to.

    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    return torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
        is_fp8_kvcache, topk)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FlashMLA forward kernel against a paged KV cache.

    Dispatches to one of two compiled ops:
    - `torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8` when `indices`
      is None and `q` has 1-byte elements; only this path receives
      `descale_q` / `descale_k`.
    - `torch.ops._flashmla_C.fwd_kvcache_mla` otherwise; only this path
      receives `is_fp8_kvcache` and `indices`.

    Arguments:
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Defaults to 1 / sqrt(head_dim), taken from the last dim of `q`.
    - causal: bool. Whether to apply causal attention mask.
        Must be False when `indices` is given.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
    - AssertionError: if `causal` is True together with `indices`, or if
        exactly one of `descale_q` / `descale_k` is provided.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    if indices is not None:
        # NOTE (zyongye): sparse attention is also causal
        # since it only attend to the tokens before
        # but here `causal` should not be specified
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    # A 1-byte element size on a dense (non-sparse) query selects the fp8 op.
    if indices is None and q.element_size() == 1:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
            indices)
    return out, softmax_lse


def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Forwards all arguments, unchanged and in order, to the compiled op
    `torch.ops._flashmla_C.sparse_prefill_fwd` and returns its result as-is.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        For the definition of output, max_logits and lse,
        please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits:  [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    sparse_prefill_op = torch.ops._flashmla_C.sparse_prefill_fwd
    return sparse_prefill_op(q, kv, indices, sm_scale, d_v)


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_decoding_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#