cutlass_mla.py 9.38 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import ClassVar
6
7
8
9

import torch

import vllm._custom_ops as ops
10
11
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
12
13
14
15
16
17
from vllm.model_executor.layers.attention.mla_attention import (
    MLACommonBackend,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
)
18
from vllm.platforms.interface import DeviceCapability
19
from vllm.utils.platform_utils import num_compute_units
20
from vllm.v1.attention.backend import (
21
    AttentionCGSupport,
22
23
    AttentionLayer,
    AttentionType,
24
    MultipleOf,
25
26
    is_quantized_kv_cache,
)
27
28
29
30

# Module-level logger, created through vLLM's init_logger helper and keyed
# by this module's name.
logger = init_logger(__name__)


31
32
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder used by the CUTLASS MLA backend.

    The only thing this subclass contributes on top of
    ``MLACommonMetadataBuilder`` is the CUDA Graph support level below.
    """

    # Full CUDA Graph support is enabled for decode-only capture.
    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )
36
37


38
class CutlassMLABackend(MLACommonBackend):
    """Backend descriptor for the CUTLASS MLA attention path.

    Declares which dtypes, kernel block size and device capability the
    backend accepts, and names the implementation and metadata-builder
    classes that belong to it.
    """

    # Activation dtypes accepted by this backend.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    # KV-cache dtype strings accepted by this backend.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "float16",
        "bfloat16",
        "fp8",
        "fp8_e4m3",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "CUTLASS_MLA"

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return CutlassMLAImpl

    @staticmethod
    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return CutlassMLAMetadataBuilder

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the supported kernel block sizes; 128 is the only entry."""
        return [128]

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True only when the capability's major version is 10."""
        return capability.major == 10
67

68

69
70
class SM100Workspace:
    """Growable uint8 buffer handed to the SM100 CUTLASS MLA op as workspace.

    The buffer is allocated on the CUDA device at construction time and is
    enlarged in place by ``ensure_size`` when a call needs more room than is
    currently held. It is never shrunk.
    """

    def __init__(self, initial_workspace_size):
        """Allocate the buffer and cache the values used for sizing.

        Args:
            initial_workspace_size: Number of uint8 elements to allocate
                up front.
        """
        self._workspace_buf = torch.empty(
            initial_workspace_size, dtype=torch.uint8, device="cuda"
        )

        # Forced to 128
        self._block_size = 128

        # The compute-unit count is looked up a single time, for device 0,
        # and reused afterwards (assumes all devices are similar).
        self._sm_count = num_compute_units(0)

    def get_buf(self):
        """Return the underlying workspace tensor."""
        return self._workspace_buf

    def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
        """Grow the workspace when the current batch needs more than is held.

        Args:
            attn_metadata: Batch metadata; ``num_reqs`` and ``max_query_len``
                are read from it.
            num_kv_splits: Forwarded unchanged to the workspace-size query.
        """
        num_reqs = attn_metadata.num_reqs
        # NOTE(review): the length bound is taken from ``max_query_len`` --
        # confirm this is the intended quantity for workspace sizing.
        len_bound = attn_metadata.max_query_len

        required_size = ops.sm100_cutlass_mla_get_workspace_size(
            len_bound * self._block_size,
            num_reqs,
            self._sm_count,
            num_kv_splits=num_kv_splits,
        )

        # Resize only when the requirement exceeds the current length.
        if required_size > self._workspace_buf.shape[0]:
            self._workspace_buf.resize_(required_size)


# Module-level workspace instance, constructed when this module is imported
# with an initial length of 128 * 1024 * 1024 uint8 elements. CutlassMLAImpl
# stores a reference to it so the same buffer is shared across executions.
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024)  # 128MB

101
102
# Head-count bound for the CUTLASS MLA path; CutlassMLAImpl passes it to the
# base-class constructor as q_pad_num_heads.
MAX_HEADS = 128

103

104
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA decode implementation that calls the SM100 CUTLASS MLA op.

    Only ``AttentionType.DECODER`` is accepted. ALiBi slopes, sliding-window
    attention and logits soft-capping are rejected at construction time, and
    non-unit q/k scales are rejected in ``forward_mqa``.

    The number of KV splits is normally left to the kernel (``-1``); setting
    the ``FORCE_NUM_KV_SPLITS`` environment variable to an integer overrides
    it.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Validate the configuration and set up per-instance state.

        All positional arguments and ``mla_args`` are forwarded to the base
        class, together with ``q_pad_num_heads=MAX_HEADS``.

        Raises:
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
            ValueError: If ``FORCE_NUM_KV_SPLITS`` is set to a non-empty
                value that is not an integer.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            q_pad_num_heads=MAX_HEADS,
            **mla_args,
        )

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "CutlassMLAImpl"
            )

        # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
        #       issues. In case the code hangs, use:
        #       FORCE_NUM_KV_SPLITS=1
        force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None)
        if force_num_kv_splits:
            # Parse once so the logged value and the stored value cannot
            # diverge, and name the variable when the value is malformed.
            try:
                self._num_kv_splits = int(force_num_kv_splits)
            except ValueError as err:
                raise ValueError(
                    "FORCE_NUM_KV_SPLITS must be an integer, "
                    f"but got {force_num_kv_splits!r}"
                ) from err
            logger.debug_once("Forcing num_kv_splits to %d", self._num_kv_splits)
        else:
            self._num_kv_splits = -1  # => Auto-detect

        # Share workspace buffer across all executions
        self._workspace = g_sm100_workspace

        # Pre-allocated output buffer, lazily sized on first call.
        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
        # from contaminating downstream per-tensor reductions.
        self._decode_out: torch.Tensor | None = None

    def _sm100_cutlass_mla_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        seq_lens: torch.Tensor,
        page_table: torch.Tensor,
        workspace: torch.Tensor,
        sm_scale: float,
        num_kv_splits: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Check the inputs, run the SM100 CUTLASS MLA decode op, trim outputs.

        Args:
            q_nope: 3D tensor ``(B, H, 512)``; fp16, bf16 or float8_e4m3fn.
            q_pe: 3D tensor ``(B, H, 64)`` with the same dtype as ``q_nope``.
            kv_c_and_k_pe_cache: 3D tensor ``(_, PAGE_SIZE, 576)`` with the
                same dtype as ``q_nope``.
            seq_lens: int32 tensor, passed through to the op.
            page_table: 2D int32 tensor ``(B, block_num)``.
            workspace: Workspace tensor, passed through to the op.
            sm_scale: Scale value, passed through to the op.
            num_kv_splits: Split count, passed through to the op.

        Returns:
            ``(out, lse)``. ``out`` is a view of the cached output buffer,
            cut to ``B`` rows and ``H`` heads. ``lse`` is cut to ``H`` heads
            when ``need_to_return_lse_for_decode`` is set, otherwise it is an
            empty ``torch.Tensor()``.

        Raises:
            AssertionError: If any of the shape or dtype checks fails.
        """
        assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
        assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
        assert kv_c_and_k_pe_cache.ndim == 3, (
            "kv_c_and_k_pe_cache must be a 3D tensor, "
            f"but got {kv_c_and_k_pe_cache.ndim}"
        )

        B_q, H, D_q_nope = q_nope.shape
        B_q_2, H_2, D_q_pe = q_pe.shape
        assert (B_q == B_q_2) and (H == H_2)

        _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

        D_latent = 512
        D_rope = 64
        assert D_q_nope == D_latent
        assert D_q_pe == D_rope
        assert D_ckv == D_latent + D_rope

        # MAX_HEADS is the module-level constant, the same one __init__
        # passes as q_pad_num_heads, so the two cannot drift apart.
        assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"

        assert len(page_table.shape) == 2
        B_block_table, block_num = page_table.shape
        assert B_block_table == B_q
        assert block_num > 0, f"block num must be greater than 0, got {block_num}"
        # NOTE(review): true division makes the divisor a float; when
        # PAGE_SIZE is 128 it is 1.0 and the check always passes.
        assert block_num % (128 / PAGE_SIZE) == 0

        assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
            f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
        )
        assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
        assert seq_lens.dtype == torch.int32, (
            f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
        )
        assert page_table.dtype == torch.int32, (
            f"page_table.dtype needs to be int32 but got {page_table.dtype}."
        )

        # With a quantized KV cache the output is bf16; otherwise it follows
        # the query dtype.
        dtype = (
            torch.bfloat16
            if is_quantized_kv_cache(self.kv_cache_dtype)
            else q_nope.dtype
        )

        # Reuse pre-allocated zero-init output buffer to avoid a memset
        # kernel on every CUDA graph replay.
        if (
            self._decode_out is None
            or self._decode_out.shape[0] < B_q
            or self._decode_out.dtype != dtype
        ):
            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
        out = self._decode_out[:B_q]

        lse = (
            torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
            if self.need_to_return_lse_for_decode
            else torch.Tensor()
        )

        ops.sm100_cutlass_mla_decode(
            out,
            lse,
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            seq_lens,
            page_table,
            workspace,
            sm_scale,
            num_kv_splits,
        )

        if H < MAX_HEADS:
            # Extract the subsets of the outputs
            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
            out = out[:, :H]

        return out, lse

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run MLA decode for one layer.

        Args:
            q: Either a ``(q_nope, q_pe)`` tuple, or a single tensor that is
                split along the last dimension into ``kv_lora_rank`` and
                ``qk_rope_head_dim`` parts.
            kv_c_and_k_pe_cache: KV cache tensor; must be non-empty.
            attn_metadata: Batch metadata; ``decode`` must be populated.
            layer: Attention layer; its q and k scales must both be 1.0.

        Returns:
            ``(output, lse)``; ``lse`` is ``None`` unless
            ``need_to_return_lse_for_decode`` is set.

        Raises:
            NotImplementedError: If the layer's q or k scale is not 1.0.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if layer._q_scale_float != 1.0 or layer._k_scale_float != 1.0:
            raise NotImplementedError(
                "CutlassMLAImpl does not support scaling for q and kv_latent yet"
            )

        # isinstance rather than an exact type check, so tuple subclasses
        # (e.g. named tuples) are unpacked instead of reaching torch.split.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        # Adjust workspace size (if necessary)
        self._workspace.ensure_size(attn_metadata, self._num_kv_splits)

        # Run MLA
        o, lse = self._sm100_cutlass_mla_decode(
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            attn_metadata.decode.seq_lens,
            attn_metadata.decode.block_table,
            self._workspace.get_buf(),
            self.scale,
            self._num_kv_splits,
        )

        return o, (lse if self.need_to_return_lse_for_decode else None)