common.py 36.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
4
5
# MLA Common Components

6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
This file implements common components for MLA implementations.

First we define:

Sq      as Q sequence length
Skv     as KV sequence length

MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Skv / Sq is "small", i.e. near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Skv / Sq is "large").

NOTE: what we deem small and large is currently determined by whether the
scheduler labels a request as prefill or decode, but this is something we
should probably tune.

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
27
* Use a single latent vector to represent the per-token entry of the KV cache. 
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.

Below is an example of both paths assuming batchsize = 1

## More Extent Definitions:

C           Context length, `Skv - Sq`
H           hidden size
N           number of attention heads
Lq          latent dimension for Q              1536 in DSV3
Lkv         latent dimension for K/V            512 in DSV3
P           nope dimension, no rope.            128 in DSV3
R           rope dimension, goes through rope.  64 in DSV3
V           V head dim.                         128 in DSV3

## Vector/Matrix Definitions

h_t         hidden states (input to attention)  shape [Sq, H]
q_c         latent/compressed Q                 shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
q_pe        uncompressed Q (rope)               shape [Sq, N, R]
kv_c        latent/compressed KV                shape [Skv, Lkv]
k_pe        decoupled k position embeddings     shape [Skv, R]
new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
new_k_pe    new k_pe from current iter          shape [Sq, R]
cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters     shape [C, R]
W_DQ        project h_t to q_c                  shape [H, Lq]
W_UQ        project q_c to q_nope               shape [Lq, N * P]
W_QR        project q_c to q_pe                 shape [Lq, N * R]
W_DKV       project h_t to kv_c                 shape [H, Lkv]
60
61
62
W_UK        project kv_c to k_nope              shape [Lkv, N, P]
W_KR        project h_t to k_pe                 shape [H, R]
W_UV        project kv_c to v                   shape [Lkv, N, V]
63
64
65
66
67
68
69
70
71
72
73
74
W_O         project v to h_t                    shape [N * V, H]


## Compute Friendly Approach (i.e. "_forward_prefill"):

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
75
76
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
77
78
79
80
81
82
83
84

// MHA with QK headdim = P + R
//           V headdim = V
//      spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
85
) 
86
87
88
return spda_o @ W_O

NOTE: in the actual code,
89
90
    `kv_b_proj` is [W_UK; W_UV] concatenated per head
    `q_b_proj` is [W_UQ; W_QR] concatenated per head
91
92
93
94
95
96
97
    `out_proj` is W_O


## Data-Movement Friendly Approach (i.e. "_forward_decode"):

Runtime
q_c      = h_t @ W_DQ
98
99
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
100
101
102
103
104
105
106
107
108
109
110
111
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)

// MQA with QK headdim = Lkv + R
//           V headdim = Lkv
//      spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
//       but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
112
    torch.cat([ql_nope, q_pe], dim=-1),
113
114
115
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)
116
117
118

o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
119
120
121
122


## Chunked Prefill

123
124
For chunked prefill we want to use the compute friendly algorithm. We are 
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to 
125
126
127
128
129
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.

However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`

130
131
To mitigate this, we chunk the computation of attention with respect to the 
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
132
133
134
135
fixed workspace size.

The chunked prefill approach is as follows:

136
MCC        Max chunk of context to process per iter, computed dynamically, 
137
138
139
140
141
142
143
           used to bound the memory usage

q_c        = h_t @ W_DQ
q_nope     = (q_c @ W_UQ).view(Sq, N, P)
q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c   = h_t @ W_DKV
new_k_pe   = RoPE(h_t @ W_KR)
144
145
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
146
147
148
149
150
151
152
153
154
155
156
157

// MHA between queries and new KV
//     with QK headdim = P + R
//           V headdim = V
//    curr_o   shape [Sq, N, V]
//    curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    causal=True,
    return_softmax_lse=True
158
) 
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192

// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start  = chunk_idx * MCC
    chunk_end    = min(chunk_start + MCC, C)
    Sc           = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV).view(-1, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                   dim=-1),
        cache_v_chunk,
        causal=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o @ W_O
"""

import functools
from abc import abstractmethod
from dataclasses import dataclass
193
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
194
195

import torch
zhuwenwen's avatar
zhuwenwen committed
196
import os
197
198

from vllm import _custom_ops as ops
199
from vllm import envs
200
201
202
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl)
203
from vllm.attention.backends.utils import get_mla_dims
204
from vllm.attention.ops.merge_attn_states import merge_attn_states
205
from vllm.attention.utils.fa_utils import get_flash_attn_version
206
207
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
208
                                               LinearBase,
209
                                               UnquantizedLinearMethod)
Simon Mo's avatar
Simon Mo committed
210
from vllm.platforms import current_platform
211
from vllm.utils import cdiv, round_down
212
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
213
214
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
215
216
217

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
218
    is_vllm_fa = True
219
220
221
except ImportError:
    # For rocm use upstream flash attention
    from flash_attn import flash_attn_varlen_func
222
    is_vllm_fa = False
223
224

if TYPE_CHECKING:
225
    from vllm.v1.core.sched.output import SchedulerOutput
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# Module-level logger, registered under this module's name.
logger = init_logger(__name__)


class MLACommonBackend(AttentionBackend):
    """Static description of the common MLA attention backend.

    The class carries no instance state; every entry point is a static
    method that returns a fixed name, class, shape or list.
    """

    # Class-level capability flag.
    # NOTE(review): presumably tells callers that the implementation can
    # write into a caller-provided output buffer -- confirm against
    # `AttentionBackend`.
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TRITON_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention-metadata class used with this backend."""
        return MLACommonMetadata

    @staticmethod
    def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
        """Return the class that builds this backend's metadata."""
        return MLACommonMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache shape `(num_blocks, block_size, head_size)`.

        `num_kv_heads` is accepted but does not appear in the returned
        shape.
        """
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the list of head sizes accepted by this backend."""
        return [576]


@dataclass
class MLACommonPrefillMetadata:
    """Attention metadata that only applies to the prefill requests."""

    @dataclass
    class ChunkedContextMetadata:
        """Per-chunk bookkeeping for attending to already-cached context.

        New for MLA (compared to FlashAttention); used for handling chunked
        prefill.
        """
        # Cumulative per-request context lengths, one row per chunk.
        cu_seq_lens: torch.Tensor
        # Start offset of each chunk within each request's context.
        starts: torch.Tensor
        # Total number of context tokens in each chunk.
        seq_tot: list[int]
        # Longest per-request context slice in each chunk.
        max_seq_lens: list[int]
        # Workspace buffer the builder hands over for chunk processing.
        workspace: torch.Tensor

    # Block-table rows belonging to the prefill requests.
    block_table: torch.Tensor
    # Query start offsets of the prefill requests.
    query_start_loc: torch.Tensor
    # Maximum query length, as passed through by the builder.
    max_query_len: int
    # Only populated when there is cached context to process in chunks.
    chunked_context: Optional[ChunkedContextMetadata] = None

281

282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
@dataclass
class MLACommonDecodeMetadata:
    """Attention metadata that only applies to the decode requests."""

    # Block-table rows belonging to the decode requests.
    block_table: torch.Tensor
    # Sequence lengths of the decode requests.
    seq_lens: torch.Tensor


# Type variable for the decode-metadata type; bounded so that any
# substituted type must be MLACommonDecodeMetadata or a subclass of it.
D = TypeVar("D", bound=MLACommonDecodeMetadata)


@dataclass
class MLACommonMetadata(Generic[D]):
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    # Decode-only / prefill-only metadata; each stays None when the batch
    # contains no request of that kind.
    decode: Optional[D] = None
    prefill: Optional[MLACommonPrefillMetadata] = None

    def __post_init__(self):
        """Validate `head_dim` against the backend's supported head sizes.

        Raises:
            ValueError: if `head_dim` is set and is not one of the sizes
                returned by `MLACommonBackend.get_supported_head_sizes()`.
        """
        supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # The two f-strings are adjacent literals (no comma between
            # them) so that ValueError receives one message string rather
            # than a 2-tuple of arguments.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")


331
# Type variable for the metadata type; bounded so that any substituted
# type must be MLACommonMetadata or a subclass of it.
M = TypeVar("M", bound=MLACommonMetadata)
332
333


334
class MLACommonMetadataBuilder(Generic[M]):
335
336
337
338
339
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

340
341
    def __init__(self,
                 runner: "GPUModelRunner",
                 kv_cache_spec: AttentionSpec,
                 block_table: BlockTable,
                 metadata_cls: Optional[type[M]] = None):
        """Capture runner configuration and set up the chunked-prefill workspace.

        Args:
            runner: Model runner. Its scheduler, model, cache and parallel
                configs and its device are read here, and the runner itself
                is kept on the builder.
            kv_cache_spec: KV-cache spec; its `block_size` becomes the page
                size when AOT scheduling is enabled.
            block_table: Block table, stored on the builder as-is.
            metadata_cls: Metadata class to instantiate in `build`;
                `MLACommonMetadata` is used when this is None.
        """
        self.metadata_cls = MLACommonMetadata \
            if metadata_cls is None else metadata_cls

        self.runner = runner
        scheduler_config = runner.scheduler_config
        model_config = runner.model_config
        cache_config = runner.cache_config

        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
        self.num_heads = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.mla_dims = get_mla_dims(model_config)
        self.aot_schedule = current_platform.is_cuda()
        self.kv_cache_spec = kv_cache_spec

        # The page size is only recorded when AOT scheduling is on (i.e. the
        # platform is CUDA); `build` reads it under the same flag.
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        if self.chunked_prefill_enabled:
            # One KV-cache block for every sequence the scheduler may run;
            # the workspace must never be smaller than this.
            one_block_per_seq = \
                scheduler_config.max_num_seqs * cache_config.block_size

            # Make sure there is enough room for 8 full-length requests, or
            # at least 4 pages of cache per request, whichever is larger.
            wanted_size = max(8 * model_config.max_model_len,
                              4 * one_block_per_seq)

            # For long-context models, avoid over-allocating (which would
            # eat into KV-cache space) by capping the workspace at 128k
            # tokens. At the cap the workspace is:
            #   2*(576)*(128*1024) = 144 MiB
            # (assuming 576 MLA head dim, and fp16)
            # and the up-projected context for a full workspace is:
            #   2*(192*128)*(128*1024) = 6 GiB
            # (assuming 192 QK head dim, 128 heads, and fp16)
            self.chunked_prefill_workspace_size = min(wanted_size, 128 * 1024)

            assert self.chunked_prefill_workspace_size >= one_block_per_seq
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 model_config.get_head_size()),
                dtype=model_config.dtype,
                device=runner.device,
            )

        self.block_table = block_table
387
388

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move "decode" requests to the front of the batch.

        Requests with exactly one scheduled token are treated as "decode",
        all others as "prefill". The decode/prefill request and token counts
        are stored on `self` for the following `build` call.

        Args:
            input_batch: Batch whose `req_ids` are classified and whose
                `swap_states` is called to exchange two positions.
            scheduler_output: Provides `num_scheduled_tokens` per request id.

        Returns:
            True if at least one swap was performed, False otherwise.
        """
        # We want the "decode" requests at the front of the batch and the
        # "prefill" requests at the back, using as few swaps as possible.
        # (NOTE: for now "decode" loosely means requests where attention is
        # likely memory-bound and "prefill" means requests where attention is
        # likely compute-bound, TODO(lucas): figure out a better naming here)
        decodes: list[int] = []
        prefills: list[int] = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for position, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # For now treat 1 scheduled token as "decode" even if it is not;
            # this should become something like < 8 in the future, but
            # currently TritonMLA._forward_decode only supports
            # num_tokens = 1.
            if num_tokens != 1:
                prefills.append(position)
                num_prefill_tokens += num_tokens
            else:
                decodes.append(position)
                num_decode_tokens += num_tokens

        num_decodes = len(decodes)
        num_prefills = len(prefills)

        # Hopefully few swaps are needed: decodes stay around for a number
        # of iterations, so they should be relatively stationary, and new
        # requests are generally appended to the persistent batch, so they
        # are already at the back.
        # Both lists are in ascending order. Walk the decodes from the back
        # and the prefills from the front: a decode sitting past where the
        # last decode belongs (index >= num_decodes) is swapped with the
        # front-most prefill not yet used. Once a decode is already inside
        # the front region, all remaining ones are too.
        modified_batch = False
        for decode_idx, prefill_idx in zip(reversed(decodes), prefills):
            if decode_idx < num_decodes:
                break
            input_batch.swap_states(prefill_idx, decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

448
449
    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor):
        """Wrap the decode block table and sequence lengths into metadata.

        Both tensors are stored unchanged on a new
        `MLACommonDecodeMetadata` instance, which is returned.
        """
        decode_metadata = MLACommonDecodeMetadata(block_table=block_table_tensor,
                                                  seq_lens=seq_lens)
        return decode_metadata

455
    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        """Assemble the attention metadata for the current batch.

        Relies on the decode/prefill counts stored on `self` by the
        preceding `reorder_batch` call: decode requests occupy the first
        `self._num_decodes` positions and prefill requests the rest.

        Args:
            num_reqs: Number of requests in the batch; must equal the stored
                decode count plus prefill count.
            num_actual_tokens: Number of tokens excluding padding; also the
                length of the slot-mapping slice that is used.
            max_query_len: Forwarded to the prefill metadata.
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Supplies `query_start_loc` and `seq_lens`.

        Returns:
            An instance of `self.metadata_cls`.
        """
        num_decodes = self._num_decodes
        num_prefills = self._num_prefills
        assert num_decodes + num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.runner.device
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
        slot_mapping_cpu = block_table.slot_mapping_cpu[:num_actual_tokens]
        slot_mapping = slot_mapping_cpu.to(device, non_blocking=True).long()

        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        prefill_metadata = None
        if num_prefills > 0:
            # Prefill requests start right after the decode requests.
            first_prefill = num_decodes

            input_batch = self.runner.input_batch
            context_lens_cpu = input_batch.num_computed_tokens_cpu_tensor[
                first_prefill:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()

            # Rebase the offsets so the first prefill request starts at 0.
            prefill_query_start_loc = (query_start_loc[first_prefill:] -
                                       query_start_loc[first_prefill])

            chunked_context_metadata = None
            # `num_prefills > 0` already holds in this branch, so only the
            # feature flag and the presence of cached context are checked.
            if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                # NOTE: it is recommended that you read the `Chunked Prefill`
                # section in the comment at the top of the file before
                # trying to understand the following code

                # Currently an equal amount of workspace is allocated to
                # each prefill that has context; a more advanced algorithm
                # could give more workspace to prefills with longer context
                # lengths.
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)

                if self.aot_schedule:
                    # align max_context_chunk to page_size by rounding down,
                    # currently the `gather_cache` kernel cannot handle
                    # `context_chunk_starts` that are not aligned to page_size
                    max_context_chunk = round_down(max_context_chunk,
                                                   self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # if `max_context_chunk = 256`, `num_chunks = 3`, and
                #   `num_prefills = 4`, create a tensor that looks like
                #  [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done on the CPU because the results
                # are converted with `tolist` further down.
                chunk_indices = torch.arange(num_chunks, dtype=torch.int32)
                chunk_starts = chunk_indices.unsqueeze(1).expand(
                    -1, num_prefills) * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                # A request whose context ends before a chunk starts gets a
                # zero-length slice in that chunk.
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                # Column 0 stays zero; the remaining columns receive the
                # running sum of the per-request chunk lengths.
                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)

                chunked_context_cls = \
                    MLACommonPrefillMetadata.ChunkedContextMetadata
                chunked_context_metadata = chunked_context_cls(
                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                    starts=chunk_starts.to(device, non_blocking=True),
                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                    workspace=self.chunked_prefill_workspace,
                )

                assert max(chunked_context_metadata.max_seq_lens) <= \
                    self.chunked_prefill_workspace_size

            prefill_metadata = MLACommonPrefillMetadata(
                block_table=block_table_tensor[first_prefill:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
            )

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:num_decodes, ...],
                seq_lens=seq_lens[:num_decodes],
            )

        return self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

569
570
571
    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Return False for any arguments; they are accepted and ignored."""
        return False

572

573
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
574
575
576
577
578
579
580
581
582
583
584
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
    ) -> None:
        """Store the MLA dimensions and pick the flash-attention entry point.

        ``alibi_slopes``, ``sliding_window``, ``blocksparse_params``,
        ``logits_soft_cap`` and ``attn_type`` are accepted but not read by
        this constructor.

        Raises:
            NotImplementedError: if ``kv_sharing_target_layer_name`` is not
                ``None``.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported for MLA")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on ROCm and the
        # latter has an additional parameter to control FA2 vs FA3.
        # The version is queried exactly once and reused below.
        self.vllm_flash_attn_version = get_flash_attn_version()
        self.flash_attn_varlen_func = flash_attn_varlen_func
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
                                  fa_version=self.vllm_flash_attn_version)

        # Read once at construction; consulted when the kv_b_proj weights are
        # post-processed after loading.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim for attention backends that do
        # not support different headdims.
        # Padding is skipped only when FA3 is in use on a device whose major
        # capability is 9 and which reports exactly 120 multiprocessors.
        # NOTE(review): the 120-multiprocessor condition is device specific;
        # confirm which hardware it is meant to select.
        self._pad_v = self.vllm_flash_attn_version is None or not (
            self.vllm_flash_attn_version == 3
            and current_platform.get_device_capability()[0] == 9
            and torch.cuda.get_device_properties(
                torch.cuda.current_device()).multi_processor_count == 120)
639
640
641
642
643
644
645
646
647
648

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        maybe_padded_v = v
        if self._pad_v:
zhuwenwen's avatar
zhuwenwen committed
649
650
            # maybe_padded_v = torch.nn.functional.pad(
            #     v, [0, q.shape[-1] - v.shape[-1]], value=0)
651
            maybe_padded_v = torch.nn.functional.pad(
zhuwenwen's avatar
zhuwenwen committed
652
653
                    v, [0, q.shape[-1] - v.shape[-1]- 32], value=0)
            maybe_padded_v = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2])
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674

        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=maybe_padded_v,
            return_softmax_lse=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )

        # Unpack the output if there is multiple results
        lse = None
        if isinstance(attn_out, tuple):
            attn_out, lse = attn_out[0], attn_out[1]

        # Remain consistent with old `flash_attn_varlen_func` where there
        # is only one output tensor if `return_softmax_lse` is False.
        if return_softmax_lse:
            return attn_out, lse
        return attn_out

675
    def _v_up_proj(self, x):
676
677
678
679
680
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
681
        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
682

683
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Derive ``self.W_UV`` and ``self.W_UK_T`` from ``self.kv_b_proj``.

        The (possibly dequantized) ``kv_b_proj`` weight is brought to shape
        (kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)), split
        per head into its K-nope and V parts, and stored as
        ``W_UV`` with shape (N, L, V) and ``W_UK_T`` with shape (N, P, L).

        Args:
            act_dtype: dtype of the identity matrix pushed through a
                quantized layer in order to recover its dense weight.
        """

        def get_layer_weight(layer):
            # Return the first weight-like attribute the layer exposes.
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            present = (name for name in WEIGHT_NAMES if hasattr(layer, name))
            first = next(present, None)
            if first is None:
                raise AttributeError(
                    f"Layer '{layer}' has no recognized weight attribute:"
                    f" {WEIGHT_NAMES}.")
            return getattr(layer, first)

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if isinstance(layer.quant_method, UnquantizedLinearMethod):
                # Unquantized: hand back the stored weight, transposed when
                # envs.VLLM_USE_NN is set.
                return layer.weight.T if envs.VLLM_USE_NN else layer.weight

            # Quantized: apply the layer to an identity matrix to recover
            # the dense weight.
            # NOTE: This should only be used offline, since it's O(N^3)
            eye = torch.eye(layer.input_size_per_partition,
                            dtype=act_dtype,
                            device=get_layer_weight(layer).device)
            dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
            del eye
            # standardize to (output, input)
            return dequant_weights.T

        # We currently do not have quantized bmm's, which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit; the extra memory overhead of this is fairly
        # low.
        # NOTE(review): two separate switches affect the layout here
        # (self.use_llama_nn and envs.VLLM_USE_NN); confirm they are meant to
        # be set together.
        keep_layout = self.use_llama_nn and isinstance(
            self.kv_b_proj.quant_method, UnquantizedLinearMethod)
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
        if not keep_layout:
            kv_b_proj_weight = kv_b_proj_weight.T

        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")

        # (L, N * (P + V)) -> (L, N, P + V), then split the last axis.
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )
        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)
736
737
738
739
740
741
742

    def _compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ):
        """Attend ``q`` to the cached context, one chunk at a time.

        For every chunk listed in the prefill metadata's chunked context,
        the cached latent/rope entries are gathered into the shared
        workspace, expanded through ``kv_b_proj`` into per-head K and V, and
        attended to without a causal mask. Per-chunk results are folded
        together with ``merge_attn_states``.

        Returns:
            ``(output, output_lse)`` merged across all chunks.
        """
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None
        chunked = prefill_metadata.chunked_context
        workspace = chunked.workspace
        kv_head_dim = self.qk_nope_head_dim + self.v_head_dim

        # NOTE(review): if `seq_tot` were empty, `output_lse` would never be
        # bound and the final return would raise.
        output = None
        for i, toks in enumerate(chunked.seq_tot):
            chunk_cu_seq_lens = chunked.cu_seq_lens[i]

            # Copy this chunk's cache entries into the workspace buffer.
            ops.gather_cache(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=chunk_cu_seq_lens,
                batch_size=attn_metadata.num_prefills,
                seq_starts=chunked.starts[i],
            )

            # Only the first `toks` workspace rows belong to this chunk; the
            # last axis holds the latent part followed by the rope part.
            gathered = workspace[:toks]
            kv_c_normed = gathered[..., :self.kv_lora_rank]
            k_pe = gathered[..., self.kv_lora_rank:].unsqueeze(1)

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, kv_head_dim)
            k_nope, v = kv_nope.split(
                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            # Broadcast the single rope key across heads and append it.
            k_pe_per_head = k_pe.expand((*k_nope.shape[:-1], -1))
            k = torch.cat((k_nope, k_pe_per_head), dim=-1)

            attn_output, attn_softmax_lse = \
                self._flash_attn_varlen_diff_headdims(
                    q=q,
                    k=k,
                    v=v,
                    cu_seqlens_q=prefill_metadata.query_start_loc,
                    cu_seqlens_k=chunk_cu_seq_lens,
                    max_seqlen_q=prefill_metadata.max_query_len,
                    max_seqlen_k=chunked.max_seq_lens[i],
                    softmax_scale=self.scale,
                    causal=False,  # Context is unmasked
                    return_softmax_lse=True,
                )

            if output is None:
                # First chunk: nothing to merge with yet.
                output = attn_output
                output_lse = attn_softmax_lse
                continue

            # Later chunks: merge into fresh buffers, then swap them in.
            merged_output = torch.empty_like(output)
            merged_lse = torch.empty_like(output_lse)
            merge_attn_states(
                output=merged_output,
                output_lse=merged_lse,
                prefix_output=output,
                prefix_lse=output_lse,
                suffix_output=attn_output,
                suffix_lse=attn_softmax_lse,
            )
            output = merged_output
            output_lse = merged_lse

        return output, output_lse

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Compute attention for the prefill tokens.

        The new tokens attend causally to themselves; when the prefill
        metadata carries a chunked context, the cached context is attended
        to separately and merged in with ``merge_attn_states``.

        Returns:
            The attention output with its last two dimensions flattened.
        """
        assert attn_metadata.prefill is not None
        prefill = attn_metadata.prefill
        has_context = prefill.chunked_context is not None

        # Expand the latent into per-head K-nope and V.
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim],
                                  dim=-1)

        # Broadcast the rope key across heads and append it to K-nope.
        k_pe_per_head = k_pe.expand((*k_nope.shape[:-1], -1))
        k = torch.cat((k_nope, k_pe_per_head), dim=-1)

        # Queries and keys cover the same new tokens, so the query offsets
        # and maximum length are used for both sides.
        attn_result = self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill.query_start_loc,
            cu_seqlens_k=prefill.query_start_loc,
            max_seqlen_q=prefill.max_query_len,
            max_seqlen_k=prefill.max_query_len,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=has_context,
        )

        if has_context:
            # The LSE was requested, so the result is an (output, lse) pair.
            suffix_output, suffix_lse = attn_result
            context_output, context_lse = self._compute_prefill_context(
                q, kv_c_and_k_pe_cache, attn_metadata)

            output = torch.empty_like(suffix_output)
            merge_attn_states(
                output=output,
                prefix_output=context_output,
                prefix_lse=context_lse,
                suffix_output=suffix_output,
                suffix_lse=suffix_lse,
            )
        else:
            output = attn_result

        # unpad if necessary
        if self._pad_v:
            output = output[..., :v.shape[-1]]

        return output.flatten(start_dim=-2)
859
860
861
862

    @abstractmethod
    def _forward_decode(
        self,
        ql_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
    ) -> torch.Tensor:
        """Compute attention for the decode tokens; subclasses implement it.

        ``forward`` calls this with the decode queries already multiplied by
        ``W_UK_T`` (``ql_nope``), the rope part of the decode queries
        (``q_pe``), the KV cache and the attention metadata, and writes the
        returned tensor into the decode rows of its output buffer.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Write the new KV entries to the cache and run MLA attention.

        Tokens are laid out with decode tokens first and prefill tokens
        after them; each group is handled by its own path and written into
        the matching rows of ``output``.

        Returns:
            The caller-supplied ``output`` buffer (including any padding
            rows), filled in place. With ``attn_metadata`` set to ``None``
            the buffer is zero-filled and returned.
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs; keep a handle on
        # the padded buffer to return, and work on the unpadded prefix.
        output_padded = output
        output = output[:num_actual_toks, ...]
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        assert (attn_metadata.num_decodes is not None
                and attn_metadata.num_prefills is not None
                and attn_metadata.num_decode_tokens is not None)
        num_decode_tokens = attn_metadata.num_decode_tokens

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if attn_metadata.num_prefills > 0:
            # Prefill tokens occupy the rows after the decode tokens.
            output[num_decode_tokens:] = self._forward_prefill(
                q[num_decode_tokens:],
                k_c_normed[num_decode_tokens:],
                k_pe[num_decode_tokens:],
                kv_cache,
                attn_metadata,
            )

        if attn_metadata.num_decodes > 0:
            assert attn_metadata.decode is not None
            q_nope, q_pe = q[:num_decode_tokens].split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            # (B, N, P) -> (N, B, P), multiply by W_UK_T (N, P, L) to get
            # (N, B, L), then back to (B, N, L).
            ql_nope = torch.bmm(q_nope.transpose(0, 1),
                                self.W_UK_T).transpose(0, 1)

            output[:num_decode_tokens] = self._forward_decode(
                ql_nope, q_pe, kv_cache, attn_metadata)

        return output_padded