utils.py 37.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import abc
4
import enum
5
import functools
6
from abc import abstractmethod
7
from dataclasses import dataclass, fields, make_dataclass
8
9
10
11
12
13
14
15
16
17
18
19
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Generic,
    Literal,
    Optional,
    Protocol,
    TypeVar,
    Union,
    get_args,
)
20

21
import numpy as np
22
import torch
23
from typing_extensions import runtime_checkable
24

25
from vllm.config import VllmConfig, get_layers_from_vllm_config
26
27
from vllm.utils import cdiv

28
if TYPE_CHECKING:
29
    from vllm.attention.backends.abstract import AttentionImpl
30
31
32
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

33
import vllm.envs as envs
34
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
35
from vllm.attention.layer import Attention
36
from vllm.distributed.kv_transfer.kv_connector.utils import (
37
38
    get_kv_connector_cache_layout,
)
39
from vllm.logger import init_logger
40
from vllm.v1.kv_cache_interface import AttentionSpec
41
from vllm.v1.worker.ubatch_utils import UBatchSlice
42
43

logger = init_logger(__name__)

# The two KV-cache layouts accepted by this module.
KVCacheLayoutType = Literal["NHD", "HND"]
# Layout forced from code through set_kv_cache_layout(); None = not forced.
_KV_CACHE_LAYOUT_OVERRIDE: Optional[KVCacheLayoutType] = None

# Sentinel slot id (presumably marks padded / unused slots — confirm at the
# use sites).
PAD_SLOT_ID = -1


def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True when *value* is one of the KVCacheLayoutType literals."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts


@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions (the ``*_cpu``
    fields mirror their un-suffixed counterparts).
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length in batch"""

    # Block table indexed first by request, then by block position;
    # presumably (batch_size, max_num_blocks) — confirm with the model runner.
    block_table_tensor: torch.Tensor
    # Per-token slot mapping; sliced by token range when a batch is split
    # into ubatches.
    slot_mapping: torch.Tensor

    # Whether attention is causal.
    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None


def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Creates a new query_start_loc that corresponds to the requests in
    request_slice.

    The result holds ``request_slice.stop - request_slice.start + 1`` offsets
    (the end offset of the last selected request is included). It is rebased
    so that the first selected request starts at 0.

    Note: This function creates a new tensor to hold the new query_start_locs.
    This will break cudagraph compatibility.
    """
    return (
        query_start_loc[request_slice.start : request_slice.stop + 1]
        - query_start_loc[request_slice.start]
    )


def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice.

    A request whose tokens straddle the token-slice boundary is split: only
    its tokens inside ``ubatch_slice.token_slice`` are kept. Its
    query_start_loc entries (and, for a trailing split, its seq_len) are
    adjusted accordingly. ``attn_metadata`` itself is not modified.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], (
        "Token slice end outside of last request"
    )

    # If a request has tokens in more than one ubatch, we have to split it.
    # With two ubatches, the first ubatch may split its last request and the
    # second may split its first request. A ubatch in the middle (e.g. one
    # request spread over three ubatches) can split both.
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )

    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if splits_first_request:
        # Drop the leading tokens of the first request that belong to an
        # earlier ubatch.
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

    if splits_last_request:
        # Trailing tokens of the last request that belong to a later ubatch.
        # Use the original absolute offsets (start_locs): query_start_loc_cpu
        # is slice-relative and may already have been shifted by the
        # first-request split above.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        #  (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )


def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Creates a new CommonAttentionMetadata instance that corresponds to the
    requests for each UBatchSlice in ubatch_slices.

    Returns one metadata object per slice, in the same order as
    ``ubatch_slices``.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
        for ubatch_slice in ubatch_slices
    ]


# Type variable for the backend-specific metadata type that an
# AttentionMetadataBuilder produces (the return type of its build methods).
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend.

    Values are ordered from no support (0) to full support (3).

    Here we do not consider the cascade attention, as currently
    it is never cudagraph supported."""

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that
    are the same; this can be used for spec-decode,
    i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1 decodes"""
    NEVER = 0
    """NO cudagraph support"""


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Base class for builders that turn the per-batch CommonAttentionMetadata
    into backend-specific attention metadata of type ``M``.
    """

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: Optional[int] = None

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """
        Store the state shared by all builders. ``__init__`` is abstract, so
        subclasses must define their own and may call this one via super().

        Args:
            kv_cache_spec: The KV cache spec of the layers this builder serves.
            layer_names: Names of the attention layers this builder serves.
            vllm_config: The vLLM config.
            device: The device given by the caller (presumably where the
                metadata tensors are placed).
        """
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: Optional[int] = 1,
        supports_spec_as_decode: bool = False,
    ) -> None:
        """
        Initialize ``self.reorder_batch_threshold``.

        Args:
            reorder_batch_threshold: The query length to use as the reorder
                threshold, or None to disable batch reordering.
            supports_spec_as_decode: Whether the backend has spec-as-decode
                kernels. When True and reordering is enabled, the threshold
                becomes ``1 + num_speculative_tokens`` if speculative decoding
                is configured.
        """
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is None or not supports_spec_as_decode:
            return

        # If the backend supports spec-as-decode kernels, then we can set
        # the reorder_batch_threshold based on the number of speculative
        # tokens from the config.
        speculative_config = self.vllm_config.speculative_config
        if (
            speculative_config is not None
            and speculative_config.num_speculative_tokens is not None
        ):
            self.reorder_batch_threshold = (
                1 + speculative_config.num_speculative_tokens
            )

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.

        Returns:
            The backend-specific attention metadata.
        """
        raise NotImplementedError

    def reorder_batch(
        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
    ) -> bool:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            input_batch: input batch
            scheduler_output: scheduler output.

        Returns:
            True if the batch was modified, False otherwise.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
    ) -> bool:
        """
        Decide whether cascade attention should be used for this batch.

        The base implementation always returns False and ignores its
        arguments (which describe the batch and model configuration).
        Builders may override it.
        """
        return False


@functools.lru_cache
def get_kv_cache_layout():
    """
    Return the KV cache layout to use.

    Resolution order:
      1. a layout forced from code via ``set_kv_cache_layout``;
      2. the ``VLLM_KV_CACHE_LAYOUT`` environment variable;
      3. the default from ``get_kv_connector_cache_layout()``.

    The result is memoized with ``functools.lru_cache``, so later changes to
    the override or the environment are not observed until the cache is
    cleared.

    Raises:
        ValueError: if ``VLLM_KV_CACHE_LAYOUT`` holds an unsupported layout.
    """
    # Format specified by the code.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.",
            cache_layout,
        )
        return cache_layout

    # Format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    # When neither the user nor the override specified a layout, get default
    if cache_layout is None:
        return get_kv_connector_cache_layout()

    # Validate explicitly: an `assert` would be stripped under `python -O`
    # and let an invalid user-provided layout through.
    if not is_valid_kv_cache_layout(cache_layout):
        raise ValueError(
            f"Invalid VLLM_KV_CACHE_LAYOUT {cache_layout!r}; "
            f"expected one of {get_args(KVCacheLayoutType)}."
        )
    logger.info_once(
        "`VLLM_KV_CACHE_LAYOUT` environment variable "
        "detected. Setting KV cache layout to %s.",
        cache_layout,
    )
    return cache_layout


def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """
    Force the KV cache layout returned by ``get_kv_cache_layout()``.

    The override takes precedence over ``VLLM_KV_CACHE_LAYOUT``.
    ``get_kv_cache_layout`` is memoized with ``functools.lru_cache``, so its
    cache is cleared here. Otherwise a layout queried before this call would
    keep being returned and the override would be silently ignored.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    get_kv_cache_layout.cache_clear()


@dataclass
class PerLayerParameters:
    """
    Attention hyperparameters of a single layer.

    Currently, FlashInfer backend only supports models in which all layers
    share the same values for the following hyperparameters. Should not be
    used for trtllm-gen backend since it supports different values for the
    following hyperparameters.
    """

    # Left sliding-window size; get_per_layer_parameters uses -1 when the
    # layer has no sliding window.
    window_left: int
    # Logits soft-cap value, or None when the layer does not set one.
    logits_soft_cap: Optional[float]
    # Softmax scale (taken from the attention impl's ``scale``).
    sm_scale: float
    # Whether the layer has attention sinks.
    has_sinks: bool = False


def get_per_layer_parameters(
    vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
    """
    Scan layers in `layer_names` and determine some hyperparameters
    to use during `plan`.

    Args:
        vllm_config: The vLLM config the attention layers are looked up in.
        layer_names: Names of the attention layers to scan.
        cls_: Expected type of every layer's ``impl``; asserted per layer.

    Returns:
        One PerLayerParameters per layer, keyed like the mapping returned by
        ``get_layers_from_vllm_config`` (presumably by layer name).
    """
    layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
    per_layer_params: dict[str, PerLayerParameters] = {}

    for key, layer in layers.items():
        impl = layer.impl
        assert isinstance(impl, cls_)

        # Infer hyperparameters from the attention layer.
        # `sliding_window` is indexed as a sequence; presumably a
        # (left, right) pair — confirm against the impl.
        window_size = getattr(impl, "sliding_window", None)
        window_left = window_size[0] if window_size is not None else -1
        logits_soft_cap = getattr(impl, "logits_soft_cap", None)
        sm_scale = impl.scale
        has_sinks = getattr(impl, "sinks", None) is not None

        per_layer_params[key] = PerLayerParameters(
            window_left, logits_soft_cap, sm_scale, has_sinks
        )

    return per_layer_params


def infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
    """
    Currently, FlashInfer backend other than trtllm-gen
    only supports models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`
    - `has_sinks`

    So this function checks that all layers share the same values for these
    hyperparameters and returns the global values (those of the first layer).
    The check is skipped when ``envs.VLLM_USE_TRTLLM_ATTENTION`` is set.

    Raises:
        ValueError: if `window_left` differs between layers.
        AssertionError: if no layers are given, or any other hyperparameter
            differs between layers.
    """
    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]

    # trtllm attention doesn't need global hyper params so disable the check
    if envs.VLLM_USE_TRTLLM_ATTENTION:
        return global_params

    for params in param_sets:
        if params.window_left != global_params.window_left:
            raise ValueError(
                "Window left is not the same for all layers. "
                "One potential fix is to set disable_sliding_window=True"
            )
        assert params == global_params, (
            "FlashInfer backend currently only supports models in which all "
            "layers share the same values for the following hyperparameters: "
            "`window_left`, `logits_soft_cap`, `sm_scale`, `has_sinks`."
        )

    return global_params


#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
#  attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
#  sequences into local ("virtual") batches, where each local batch item is a
#  local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """
    Split every request into local-attention ("virtual") batch items of at
    most ``attn_chunk_size`` KV tokens each (see the comment above).

    Args:
        attn_chunk_size: Size of a local attention block, in tokens.
        common_attn_metadata: Metadata of the original batch.
        block_size: KV cache block size; must be positive and divide
            ``attn_chunk_size``.

    Returns:
        A new CommonAttentionMetadata whose requests are the virtual batch
        items. ``num_actual_tokens`` and ``slot_mapping`` are passed through
        unchanged.

    Raises:
        AssertionError: if ``block_size`` is not positive or does not divide
            ``attn_chunk_size``.
    """
    # Validate before any division: block_size defaults to 0, which would
    # otherwise fail later with an unhelpful division-by-zero error.
    assert block_size > 0, f"block_size must be positive, got {block_size}"
    assert attn_chunk_size % block_size == 0, (
        f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
    )

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = block_starts[:, None] + np.arange(
        pages_per_local_batch, dtype=np.int32
    )
    # Clip so trailing pages of a local block never index past the table.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)
    block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
        virtual_batches, -1
    )

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # int(): the field is typed int; .max() alone yields a numpy scalar.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )


def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Restrict the batch's queries to the tokens listed in ``logits_indices``.

    For each request, only the query tokens whose positions appear in
    ``logits_indices_padded[:num_logits_indices]`` are kept. query_start_loc,
    num_actual_tokens and max_query_len are recomputed. seq_lens, the block
    table, slot mapping, num_computed_tokens_cpu and max_seq_len are passed
    through unchanged.

    If ``max_query_len == 1`` the input metadata is returned as-is.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    logits_indices_padded = common_attn_metadata.logits_indices_padded
    num_logits_indices = common_attn_metadata.num_logits_indices
    # Get rid of CUDAGraph padding, if any
    logits_indices = logits_indices_padded[:num_logits_indices]
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    # Example inputs
    # num_reqs: 3
    # generation_indices:  [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens:        [41, 31, 40]

    # Find how many decode indices belong to each request
    # request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)

    # Figure out how many tokens are in each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # Calculate new query_start_loc with tokens in generation_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(
        num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype
    )
    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)

    # Blocking device->host copy. A non_blocking copy here could be read on
    # the host before it completes, since nothing synchronizes afterwards.
    # This single sync replaces the two separate .item() syncs that
    # max()/sum() on the device tensor would need.
    decode_query_start_loc_cpu = decode_query_start_loc.to("cpu")
    decode_query_lens_cpu = (
        decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1]
    )
    decode_max_query_len = int(decode_query_lens_cpu.max().item())
    total_num_decode_tokens = int(decode_query_start_loc_cpu[-1].item())

    return CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        query_start_loc_cpu=decode_query_start_loc_cpu,
        seq_lens=seq_lens,
        # seq_lens is unchanged, so reuse its existing host mirror instead of
        # issuing another (unsynchronized) device->host transfer.
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_decode_tokens,
        max_query_len=decode_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )


def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    The subclass is named ``name_prefix + attention_backend_cls.__name__``
    and inherits everything else from ``attention_backend_cls``.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    # Wrap in staticmethod so the override works when called on an instance
    # as well as on the class (a bare lambda would receive `self`).
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )


def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    """
    Assuming a reordered batch, finds the boundary between prefill and decode
    requests.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.
        require_uniform: If True, requires that all decode requests have the
            same query length. When set, some queries may be considered prefills
            even if they are <= decode_threshold, in order to ensure uniformity.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    # An empty batch has no per-request lengths to inspect. Report it the same
    # way as the all-decode fast path below instead of indexing into an empty
    # `query_lens` tensor (which raised IndexError when require_uniform was set
    # with a threshold above 1).
    if num_reqs == 0:
        return 0, 0, num_tokens, 0

    # Fast path: every request is short enough to be a decode. With
    # require_uniform and a threshold above 1, short requests may still differ
    # in length, so the per-request check below is required in that case.
    if max_query_len <= decode_threshold and (
        not require_uniform or decode_threshold <= 1
    ):
        return num_reqs, 0, num_tokens, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    if query_lens[0].item() > decode_threshold:
        # first request is not decode, so no decode requests
        return 0, num_reqs, 0, num_tokens

    if require_uniform:
        # Decodes are the leading run of requests sharing the first length.
        is_prefill = query_lens != query_lens[0]
    else:
        is_prefill = query_lens > decode_threshold

    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0

    # The batch is assumed to be reordered with decodes first, so the first
    # True in the mask marks the split point.
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = query_start_loc[first_prefill].item()
    num_prefill_tokens = num_tokens - num_decode_tokens
    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    Args:
        input_batch: The batch to reorder in place; positions are exchanged via
            `input_batch.swap_states(i, j)`.
        scheduler_output: Provides `num_scheduled_tokens`, mapping each request
            id in `input_batch.req_ids` to its scheduled token count.
        decode_threshold: Requests with at most this many scheduled tokens are
            treated as "decode" requests.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # NOTE: "decode" is used loosely to mean requests where attention is likely
    # memory-bound and "prefill" to mean requests where attention is likely
    # compute-bound. TODO(lucas): figure out a better naming here.
    decodes: list[int] = []
    prefills: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        if scheduler_output.num_scheduled_tokens[req_id] <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # Reorder with as few swaps as possible. Decodes tend to stay in the batch
    # for many iterations, and new requests are generally appended to the back
    # of the persistent batch, so hopefully few swaps are needed.
    # Every decode at index >= num_decodes is out of place, and there are
    # exactly as many prefills at index < num_decodes. Since `decodes` and
    # `prefills` are ascending, walk the decodes from the back and the prefills
    # from the front, swapping pairs until the next decode is already in the
    # front region.
    num_decodes = len(decodes)
    modified_batch = False
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break
        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True

    return modified_batch
867
868


869
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Reshapes the query tensor for the specified batch size, so that
    it has shape (batch_size, seq_len, num_heads, head_dim).

    Args:
        query: 3D tensor of shape (total_tokens, num_heads, head_dim).
        batch_size: Number of sequences; must evenly divide total_tokens.

    Returns:
        A view of `query` with shape
        (batch_size, total_tokens // batch_size, num_heads, head_dim).
    """
    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
    total_tokens, num_heads, head_dim = query.shape
    assert total_tokens % batch_size == 0, (
        f"{total_tokens=} is not divisible by {batch_size=}"
    )
    seq_len = total_tokens // batch_size
    return query.view(batch_size, seq_len, num_heads, head_dim)


885
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
    """
    Reshapes the attention output tensor, so that
    the batch_size and seq_len dimensions are combined.

    A 3D input is returned unchanged. A 4D input of shape
    (batch_size, seq_len, num_heads, head_dim) is returned as a view of shape
    (batch_size * seq_len, num_heads, head_dim).
    """
    if attn_output.dim() == 3:
        # Already in the correct shape
        return attn_output
    assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
    batch_size, seq_len, num_heads, head_dim = attn_output.shape
    return attn_output.view(batch_size * seq_len, num_heads, head_dim)
896
897


898
# Extra metadata fields for KV-sharing fast prefill, as `(name, type, default)`
# triples (the same shape `subclass_attention_metadata` accepts for `fields`).
# The names match the attributes of `KVSharingFastPrefillMetadata`.
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ("logits_indices_padded", Optional[torch.Tensor], None),
    ("num_logits_indices", int, 0),
]


def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields.

    Args:
        name_prefix: Prepended to `metadata_cls.__name__` to form the name of
            the generated class.
        metadata_cls: Base class of the generated dataclass.
        fields: Extra fields as `(name, type, default)` triples, passed
            unchanged to `dataclasses.make_dataclass`.

    Returns:
        The generated dataclass, a subclass of `metadata_cls`.
    """
    # NOTE: the `fields` parameter shadows `dataclasses.fields` inside this
    # function. It is part of the public signature, so it is not renamed.
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    return make_dataclass(name, fields, bases=(metadata_cls,))


917
918
919
920
921
922
923
924
925
926
927
928
929
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """
    Structural type for attention metadata that also carries the
    KV-sharing fast-prefill attributes.

    Because the protocol is runtime-checkable, `isinstance(obj,
    KVSharingFastPrefillMetadata)` only checks that both attributes exist on
    `obj`; their types are not verified at runtime.
    """

    # Padded logits indices tensor (exact meaning defined by the code that
    # populates it -- NOTE(review): confirm against the model code).
    logits_indices_padded: torch.Tensor
    # Count that accompanies `logits_indices_padded` (presumably the number of
    # valid, unpadded entries -- TODO confirm).
    num_logits_indices: int


def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
    """
    Wrap an attention backend class so its builder produces KV-sharing
    fast-prefill metadata.

    The returned backend subclasses `underlying_attn_backend`. Its builder
    rewrites the common metadata with
    `make_kv_sharing_fast_prefill_common_attn_metadata` and delegates to the
    underlying builder's `build`. It then returns that metadata re-wrapped so it
    also carries `logits_indices_padded` and `num_logits_indices`.

    Args:
        prefix: Prepended to the underlying backend's class name to name the
            new backend class.
        underlying_attn_backend: The backend class to wrap; its
            `get_builder_cls()` supplies the builder that is extended.

    Returns:
        The new backend class.
    """
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_common_attn_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            )
            metadata = super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

            # Subclass the concrete metadata class returned by the underlying
            # builder, so the result is still an instance of that class while
            # also exposing the fast-prefill attributes.
            # NOTE(review): this defines a new class on every build() call.
            class KVSharingFastPrefillAttentionMetadata(
                metadata.__class__,  # type: ignore
                KVSharingFastPrefillMetadata,
            ):
                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all dataclass fields of the metadata class
                    # (assumes that class is a dataclass; `fields` raises
                    # TypeError otherwise).
                    for field in fields(metadata.__class__):
                        setattr(self, field.name, getattr(metadata, field.name))

                    # Set additional fields that will be used in model code
                    assert (
                        common_attn_metadata.logits_indices_padded is not None
                        and common_attn_metadata.num_logits_indices is not None
                    )
                    self.logits_indices_padded = (
                        common_attn_metadata.logits_indices_padded
                    )
                    self.num_logits_indices = common_attn_metadata.num_logits_indices

            return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder,
    )

    return attn_backend
971
972
973
974


def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """
    Precompute the chunk bookkeeping needed by causal_conv1d.

    Each request's sequence is split into chunks of BLOCK_M tokens. For every
    chunk, `batch_ptr` holds the index of the owning request and
    `token_chunk_offset_ptr` holds the chunk's index within that request. Both
    buffers are int32, live on `query_start_loc_p.device`, and are filled with
    PAD_SLOT_ID past the last valid chunk.

    Args:
        query_start_loc_p: Cumulative start offsets of the requests; per-request
            lengths are taken as its consecutive differences (presumably shape
            (num_requests + 1,), like query_start_loc -- confirm with callers).

    Returns:
        `(nums_dict, batch_ptr, token_chunk_offset_ptr)`, where
        `nums_dict[BLOCK_M]` holds the intermediates "nums", "tot", "mlist",
        "mlist_len", "offsetlist", "batch_ptr" and "token_chunk_offset_ptr".
    """
    # Needed for causal_conv1d. Per-request lengths are moved to the CPU
    # because the bookkeeping below is built with NumPy / host tensors.
    seqlens = query_start_loc_p.diff().to("cpu")
    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None
    device = query_start_loc_p.device
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Number of BLOCK_M-sized chunks per request (ceil division).
        nums = -(-seqlens // BLOCK_M)
        nums_dict[BLOCK_M] = {}
        nums_dict[BLOCK_M]["nums"] = nums
        nums_dict[BLOCK_M]["tot"] = nums.sum().item()
        # Owning request index of every chunk: request i repeated nums[i] times.
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        nums_dict[BLOCK_M]["mlist"] = mlist
        mlist_len = len(nums_dict[BLOCK_M]["mlist"])
        nums_dict[BLOCK_M]["mlist_len"] = mlist_len
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
        # Index of every chunk within its own request, i.e. the concatenation
        # of range(nums[i]) over all requests. Vectorized: a chunk's global
        # index minus the global index of its request's first chunk.
        counts = nums.to(torch.int64)
        first_chunk = torch.cumsum(counts, dim=0) - counts
        offsetlist = (
            torch.arange(mlist_len, dtype=torch.int64)
            - torch.repeat_interleave(first_chunk, counts)
        ).to(torch.int32)
        nums_dict[BLOCK_M]["offsetlist"] = offsetlist

        if batch_ptr is None:
            # First BLOCK_M: allocate the pointer buffers, pre-filled with
            # PAD_SLOT_ID.
            batch_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
            token_chunk_offset_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
        else:
            # Only reachable with more than one BLOCK_M: grow the buffers if
            # this block size needs more programs.
            # NOTE(review): all BLOCK_M entries would then share (and
            # overwrite) the same two buffers.
            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
                token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS
                ).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[0:mlist_len].copy_(offsetlist)  # type: ignore
        nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
        nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

    return nums_dict, batch_ptr, token_chunk_offset_ptr