cutlass_mla.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from typing import ClassVar, Optional
6
7
8
9

import torch

import vllm._custom_ops as ops
10
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
11
12
13
14
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
15
16
17
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
18
19
20
21

# Module-level logger; used below for the one-time warnings emitted when an
# environment-variable override is active.
logger = init_logger(__name__)


22
23
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder used by the CUTLASS MLA backend.

    The only thing defined here is the CUDA graph support level; everything
    else is inherited from ``MLACommonMetadataBuilder``.
    """

    # Advertise UNIFORM_SINGLE_TOKEN_DECODE so that full CUDA graphs can be
    # captured for decode-only batches.
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE)
26
27


28
29
30
31
class CutlassMLABackend(MLACommonBackend):
    """Attention backend entry point for the CUTLASS MLA kernels.

    Exposes the backend name together with the metadata builder and the
    attention implementation classes defined in this module.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "CUTLASS_MLA"

    @staticmethod
    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return CutlassMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return CutlassMLAImpl

42

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class SM100Workspace:
    """Growable scratch buffer passed to the SM100 CUTLASS MLA decode op.

    The buffer is a 1-D ``torch.uint8`` tensor allocated on ``"cuda"``.
    ``ensure_size`` enlarges it in place whenever the size reported by
    ``ops.sm100_cutlass_mla_get_workspace_size`` exceeds its current length;
    it is never shrunk.
    """

    def __init__(self, initial_workspace_size):
        """Allocate the buffer and cache the values used for size queries.

        Args:
            initial_workspace_size: Initial length of the buffer, in bytes
                (one ``uint8`` element per byte).
        """
        self._workspace_buf = torch.empty(initial_workspace_size,
                                          dtype=torch.uint8,
                                          device="cuda")

        # Forced to 128
        self._block_size = 128

        # Look up the SM count a single time instead of on every size query.
        # Device 0 stands in for all devices (assumes they are similar).
        device_props = torch.cuda.get_device_properties(
            torch.device("cuda:0"))
        self._sm_count = device_props.multi_processor_count

    def get_buf(self):
        """Return the underlying workspace tensor."""
        return self._workspace_buf

    def ensure_size(self, attn_metadata: MLACommonMetadata,
                    num_kv_splits: int):
        """Grow the buffer if the current batch needs more workspace.

        Args:
            attn_metadata: Batch metadata; ``num_reqs`` and ``max_query_len``
                are read from it.
            num_kv_splits: Forwarded unchanged to the workspace size query.
        """
        num_requests = attn_metadata.num_reqs
        # NOTE(review): the value used as the sequence length is
        # ``max_query_len`` scaled by the block size -- confirm this is the
        # intended bound.
        longest_len = attn_metadata.max_query_len

        required_size = ops.sm100_cutlass_mla_get_workspace_size(
            longest_len * self._block_size,
            num_requests,
            self._sm_count,
            num_kv_splits=num_kv_splits)

        if self._workspace_buf.shape[0] >= required_size:
            # Already large enough; nothing to do.
            return
        self._workspace_buf.resize_(required_size)


# Single workspace instance created at import time and shared by every
# CutlassMLAImpl in this process. It starts at 128 MiB.
g_sm100_workspace = SM100Workspace(initial_workspace_size=128 * 1024 * 1024)


78
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA attention implementation whose decode step calls CUTLASS kernels.

    Decode is dispatched to ``ops.sm100_cutlass_mla_decode`` by default. Two
    environment variables, read once in ``__init__``, change that:

    * ``FORCE_OLD_CUTLASS_MLA`` -- any non-empty value selects the older
      ``ops.cutlass_mla_decode`` kernel instead.
    * ``FORCE_NUM_KV_SPLITS`` -- an integer that overrides the number of KV
      splits passed to the SM100 kernel (``-1``, the default, means
      auto-detect).
    """

    can_return_lse_for_decode: bool = True

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Validate the configuration and read the environment overrides.

        All arguments are forwarded unchanged to ``MLACommonImpl.__init__``.

        Raises:
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window`` or ``logits_soft_cap`` is truthy, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
            ValueError: If ``FORCE_NUM_KV_SPLITS`` is set to a non-empty
                value that is not an integer.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CutlassMLAImpl")

        # Opt-in fallback to the previous kernel. The variable is only tested
        # for truthiness, so any non-empty string enables it.
        self._use_old_cutlass_mla = bool(
            os.environ.get("FORCE_OLD_CUTLASS_MLA"))
        if self._use_old_cutlass_mla:
            logger.warning_once("Forcing old cutlass mla kernel")

        # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
        #       issues. In case the code hangs, use:
        #       FORCE_NUM_KV_SPLITS=1
        force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS")
        if force_num_kv_splits:
            # Parse once so the logged and the stored value cannot diverge.
            self._num_kv_splits = int(force_num_kv_splits)
            logger.warning_once("Forcing num_kv_splits to %d",
                                self._num_kv_splits)
        else:
            self._num_kv_splits = -1  # => Auto-detect

        # Share workspace buffer across all executions
        self._workspace = g_sm100_workspace

    def _sm100_cutlass_mla_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        seq_lens: torch.Tensor,
        page_table: torch.Tensor,
        workspace: torch.Tensor,
        sm_scale: float,
        num_kv_splits: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Check the inputs, pad the heads and call the SM100 decode op.

        Args:
            q_nope: 3-D query tensor whose last dimension must be 512.
            q_pe: 3-D query tensor with the same first two dimensions as
                ``q_nope`` and a last dimension of 64.
            kv_c_and_k_pe_cache: 3-D cache tensor whose last dimension must
                be 576 (512 + 64); its second dimension is the page size.
            seq_lens: ``int32`` tensor, passed through to the op.
            page_table: 2-D ``int32`` tensor whose first dimension equals the
                first dimension of ``q_nope``.
            workspace: Scratch buffer, passed through to the op.
            sm_scale: Passed through to the op.
            num_kv_splits: Passed through to the op.

        Returns:
            ``(out, lse)``. ``out`` is the op output sliced back to the
            original head count and made contiguous. ``lse`` is sliced the
            same way when ``self.need_to_return_lse_for_decode`` is set,
            otherwise it is an empty ``torch.Tensor()``.

        Raises:
            AssertionError: If any shape or dtype check fails.
        """
        assert q_nope.ndim == 3, (
            f"q_nope must be a 3D tensor, but got {q_nope.ndim}")
        assert q_pe.ndim == 3, (
            f"q_pe must be a 3D tensor, but got {q_pe.ndim}")
        assert kv_c_and_k_pe_cache.ndim == 3, (
            "kv_c_and_k_pe_cache must be a 3D tensor, but got "
            f"{kv_c_and_k_pe_cache.ndim}")

        B_q, H, D_q_nope = q_nope.shape
        B_q_2, H_2, D_q_pe = q_pe.shape
        assert (B_q == B_q_2) and (H == H_2)

        _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

        D_latent = 512
        D_rope = 64
        assert D_q_nope == D_latent
        assert D_q_pe == D_rope
        assert D_ckv == D_latent + D_rope

        # The op is always called with MAX_HEADS heads. With fewer heads the
        # queries are copied into larger buffers; the extra rows are left
        # uninitialized and their outputs are sliced away before returning.
        MAX_HEADS = 128
        assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
        if H < MAX_HEADS:
            q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
            q_nope_padded[:, :H] = q_nope
            q_nope = q_nope_padded

            q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
            q_pe_padded[:, :H] = q_pe
            q_pe = q_pe_padded

        assert len(page_table.shape) == 2
        B_block_table, block_num = page_table.shape
        assert B_block_table == B_q
        assert block_num > 0, (
            f"block num must be greater than 0, got {block_num}")
        # NOTE(review): true division makes the modulus a float; presumably
        # PAGE_SIZE is expected to divide 128 -- confirm before changing.
        assert block_num % (128 / PAGE_SIZE) == 0

        assert q_nope.dtype in (
            torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
                f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
                f"{q_nope.dtype}.")
        assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
        assert seq_lens.dtype == torch.int32, (
            f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}.")
        assert page_table.dtype == torch.int32, (
            f"page_table.dtype needs to be int32 but got {page_table.dtype}.")

        # With a quantized KV cache the output is produced in bfloat16;
        # otherwise it keeps the query dtype.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            out_dtype = torch.bfloat16
        else:
            out_dtype = q_nope.dtype
        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=out_dtype)

        # An empty placeholder tensor is handed to the op when no LSE is
        # requested.
        if self.need_to_return_lse_for_decode:
            lse = torch.empty((B_q, MAX_HEADS),
                              dtype=torch.float32,
                              device=q_nope.device)
        else:
            lse = torch.Tensor()

        ops.sm100_cutlass_mla_decode(
            out,
            lse,
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            seq_lens,
            page_table,
            workspace,
            sm_scale,
            num_kv_splits,
        )

        # Drop the padded heads again.
        if self.need_to_return_lse_for_decode:
            lse = lse[:, :H].contiguous()
        return out[:, :H].contiguous(), lse

    def _sm100_forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run decode through the SM100 kernel using the shared workspace.

        Args:
            q_nope: First query component, forwarded to the kernel wrapper.
            q_pe: Second query component, forwarded to the kernel wrapper.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Batch metadata; its ``decode`` field must be set
                and supplies ``seq_lens`` and ``block_table``.

        Returns:
            ``(output, lse)`` where ``lse`` is ``None`` unless
            ``self.need_to_return_lse_for_decode`` is set.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # Adjust workspace size (if necessary)
        self._workspace.ensure_size(attn_metadata, self._num_kv_splits)

        # Run MLA
        # Clone q_nope and q_pe to make sure strides computation is correct.
        # TODO: Check if we really need it
        q_nope = q_nope.clone()
        q_pe = q_pe.clone()

        o, lse = self._sm100_cutlass_mla_decode(
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            attn_metadata.decode.seq_lens,
            attn_metadata.decode.block_table,
            self._workspace.get_buf(),
            self.scale,
            self._num_kv_splits,
        )

        return o, (lse if self.need_to_return_lse_for_decode else None)

    # TODO: Currently we leave it here only for backup in case something is
    #       wrong with the new SM100 CUTLASS MLA kernel
    def _old_forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Run decode through the older ``ops.cutlass_mla_decode`` kernel.

        Args:
            q_nope: First query component; its first dimension gives the
                batch size of the output.
            q_pe: Second query component.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Batch metadata; its ``decode`` field must be set
                and supplies ``seq_lens`` and ``block_table``.

        Returns:
            Output tensor of shape ``(B, self.num_heads, self.kv_lora_rank)``
            with the dtype and device of ``q_nope``.

        Raises:
            NotImplementedError: If the KV cache dtype is quantized.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")

        B = q_nope.shape[0]

        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
                        dtype=q_nope.dtype,
                        device=q_nope.device)

        # Run MLA
        # Clone q_nope and q_pe to make sure strides computation is correct.
        q_nope = q_nope.clone()
        q_pe = q_pe.clone()

        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                               attn_metadata.decode.seq_lens,
                               attn_metadata.decode.block_table, self.scale)

        return o

    def _forward_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split the query if needed and dispatch to the selected kernel.

        Args:
            q: Either a ``(q_nope, q_pe)`` tuple, or a single tensor that is
                split along its last dimension into chunks of
                ``self.kv_lora_rank`` and ``self.qk_rope_head_dim``.
            kv_c_and_k_pe_cache: KV cache tensor forwarded to the kernel.
            attn_metadata: Batch metadata forwarded to the kernel.
            layer: Not used by this implementation.

        Returns:
            ``(output, lse)``. ``lse`` is always ``None`` on the old-kernel
            path.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # tuple subclasses such as named tuples.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        if self._use_old_cutlass_mla:
            # TODO: Remove the old cutlass MLA kernel after more extensive
            #       testing
            return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
                                            attn_metadata), None

        return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
                                          attn_metadata)