utils.py 41 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import abc
4
import enum
5
import functools
6
from abc import abstractmethod
7
from dataclasses import dataclass, field, fields, make_dataclass
8
9
10
11
12
13
14
15
16
17
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Generic,
    Literal,
    Protocol,
    TypeVar,
    get_args,
)
18

19
import numpy as np
20
import torch
21
from typing_extensions import runtime_checkable
22

23
from vllm.config import VllmConfig, get_layers_from_vllm_config
24
from vllm.utils.math_utils import cdiv
25

26
if TYPE_CHECKING:
27
    from vllm.attention.backends.abstract import AttentionImpl
28
29
30
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

31
import vllm.envs as envs
32
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
33
from vllm.distributed.kv_transfer.kv_connector.utils import (
34
35
    get_kv_connector_cache_layout,
)
36
from vllm.logger import init_logger
37
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
38
from vllm.v1.kv_cache_interface import AttentionSpec
39
from vllm.v1.worker.ubatch_utils import UBatchSlice
40
41

logger = init_logger(__name__)
42
KVCacheLayoutType = Literal["NHD", "HND"]
43
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
44

45
46
PAD_SLOT_ID = -1

47
48
49

def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True if `value` is one of the literals of `KVCacheLayoutType`."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
50

51
52
53
54

@dataclass
class CommonAttentionMetadata:
    """
55
56
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.
57

58
    For many of the tensors we keep both GPU and CPU versions.
59
60
61
    """

    query_start_loc: torch.Tensor
62
    query_start_loc_cpu: torch.Tensor
63
    """(batch_size + 1,), the start location of each request in query Tensor"""
64

65
    seq_lens: torch.Tensor
66
    seq_lens_cpu: torch.Tensor
67
68
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""
69

70
71
72
    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

73
74
75
76
77
78
    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
79
80
    max_seq_len: int
    """Longest context length in batch"""
81

82
83
84
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

85
86
    causal: bool = True

87
    # Needed by FastPrefillAttentionBuilder
88
89
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None
90

91
    # Needed by CrossAttentionBuilder
92
    encoder_seq_lens: np.ndarray | None = None
93

94
    dcp_local_seq_lens: torch.Tensor | None = None
95
96
    """Sequence lengths of the local rank in decode context parallelism world"""

97

98
99
100
101
102
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for just the requests selected by
    request_slice, rebased so that the first selected request starts at 0.

    Note: The result is a newly created tensor, which breaks cudagraph
    compatibility.
    """
    first_req, end_req = request_slice.start, request_slice.stop
    # One extra entry is needed to keep the end offset of the last request.
    selected = query_start_loc[first_req : end_req + 1]
    base_offset = query_start_loc[first_req]
    return selected - base_offset
113
114
115


def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    Create a new CommonAttentionMetadata that corresponds to the requests
    (and tokens) included in ubatch_slice.

    A request whose tokens straddle a boundary of the slice is split: the
    query start locations (and, for a truncated last request, its sequence
    length) are adjusted so that they only cover the tokens that lie inside
    `ubatch_slice.token_slice`.

    Args:
        ubatch_slice: Non-empty slice providing the request range
            (`request_slice`) and the token range (`token_slice`).
        attn_metadata: Metadata of the full batch. Its tensors are not
            modified in place.

    Returns:
        The metadata restricted to `ubatch_slice`. Only the fields passed to
        the constructor below are carried over; every other field keeps its
        dataclass default.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    # Absolute (full-batch) start offsets; never mutated in this function.
    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], (
        "Token slice end outside of last request"
    )

    # If the "middle" request has tokens in both ubatches, we have to split it.
    # If ubatch_slice is the first ubatch then we will be splitting the last
    # request. If it's the second microbatch, then we will be splitting the
    # first request
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    # Both results are fresh tensors rebased to start at 0, so the in-place
    # updates below do not touch attn_metadata.
    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )

    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if splits_first_request:
        # The leading tokens of the first request belong to another ubatch;
        # shift every later start location down by that amount.
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped
    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

    if splits_last_request:
        # Measure the cut-off tail against the original, absolute start_locs.
        # query_start_loc_cpu has already been rebased (and possibly shifted
        # by the first-request split), so comparing it with the absolute
        # token_slice.stop is only correct when the slice starts at token 0.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        #  (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )


def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Produce one CommonAttentionMetadata per UBatchSlice in ubatch_slices,
    each covering only the requests of its slice. The results are returned
    in the same order as ubatch_slices.

    Note: common_attn_metadata itself is left unmodified.
    """
    return [
        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
        for ubatch_slice in ubatch_slices
    ]


225
226
227
# Type variable for the metadata type an AttentionMetadataBuilder produces
# (the return type of its `build` method).
M = TypeVar("M")


228
class AttentionCGSupport(enum.Enum):
    """Levels of cudagraph support an attention backend can declare.

    Cascade attention is not taken into account here, since it is currently
    never cudagraph supported."""

    ALWAYS = 3
    """Cudagraph is always supported, including mixed-prefill-decode batches"""
    UNIFORM_BATCH = 2
    """Cudagraph is supported for batches that only contain queries of one
    and the same length; this can be used for spec-decode,
        i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph is supported for batches that only contain query_len==1
    decodes"""
    NEVER = 0
    """Cudagraph is not supported at all"""


245
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """Abstract base for builders that turn a CommonAttentionMetadata into
    attention metadata of type `M`."""

    # Cudagraph support level of this backend/builder for attention
    # (default: never). Read it through get_cudagraph_support() rather than
    # accessing it directly.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
    # None when this backend/builder does not reorder the batch. Otherwise
    # the query length that will be pulled into the front of the batch.
    reorder_batch_threshold: int | None = None

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Store the KV cache spec, layer names, config and device on self."""
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @classmethod
    def get_cudagraph_support(
        cls: type["AttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """Return the cudagraph support level declared by this builder class.

        The base implementation ignores both arguments.
        """
        return cls._cudagraph_support

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
        """Set `self.reorder_batch_threshold` from the argument and config.

        Args:
            reorder_batch_threshold: Starting value for the threshold.
            supports_spec_as_decode: When True and the starting value is not
                None, the threshold is raised to at least
                1 + num_speculative_tokens if the config specifies it.
            supports_dcp_with_varlen: When False and the decode context
                parallel size is greater than 1, the threshold is forced to 1.
        """
        self.reorder_batch_threshold = reorder_batch_threshold

        if reorder_batch_threshold is not None and supports_spec_as_decode:
            # Backends with spec-as-decode kernels can derive the threshold
            # from the number of speculative tokens in the config.
            spec_config = self.vllm_config.speculative_config
            num_spec_tokens = (
                None if spec_config is None else spec_config.num_speculative_tokens
            )
            if num_spec_tokens is not None:
                self.reorder_batch_threshold = max(
                    reorder_batch_threshold, 1 + num_spec_tokens
                )

        dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
        if dcp_size > 1 and not supports_dcp_with_varlen:
            self.reorder_batch_threshold = 1

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Build the attention metadata; this is the central builder method.
        Some builders (MLA) need reorder_batch to have been called first.

        Args:
            common_prefix_len: Length of the prefix common to the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: Favor the speed of building the metadata over the
                speed at execution. Useful for spec-decode, where the result
                of a build call may only be used for a few layers/iters.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build the attention metadata used during CUDA graph capture.

        Defaults to `self.build` with a common prefix length of 0. Subclasses
        overriding this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        metadata = self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )
        return metadata

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build the attention metadata for the draft model.

        Defaults to `self.build` with a common prefix length of 0 and
        fast_build enabled; the base implementation does not use draft_index.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: Index of the current draft operation.
                When a chain of tokens is speculated, it refers to the draft
                attempt for the i-th token.
                With tree-based attention, it refers to the draft attempt for
                the i-th level in the tree of tokens instead.
        """
        metadata = self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )
        return metadata

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        """Report whether cascade attention should be used.

        The base implementation always answers False.
        """
        return False

371

372
373
@functools.lru_cache
def get_kv_cache_layout():
    """Resolve the KV cache layout to use.

    Precedence: the in-code override (`_KV_CACHE_LAYOUT_OVERRIDE`), then the
    `VLLM_KV_CACHE_LAYOUT` environment variable, then the default reported by
    `get_kv_connector_cache_layout()`. The result is memoized by
    `functools.lru_cache`.
    """
    # Layout forced by the code.
    override = _KV_CACHE_LAYOUT_OVERRIDE
    if override is not None:
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.",
            override,
        )
        return override

    # Layout requested by the user.
    env_layout = envs.VLLM_KV_CACHE_LAYOUT
    if env_layout is None:
        # Neither the override nor the user chose a layout: use the default.
        return get_kv_connector_cache_layout()

    assert is_valid_kv_cache_layout(env_layout)
    logger.info_once(
        "`VLLM_KV_CACHE_LAYOUT` environment variable "
        "detected. Setting KV cache layout to %s.",
        env_layout,
    )
    return env_layout
399
400


401
def set_kv_cache_layout(cache_layout: KVCacheLayoutType) -> None:
    """Install `cache_layout` as the module-level KV cache layout override.

    The value is stored in `_KV_CACHE_LAYOUT_OVERRIDE`, which
    `get_kv_cache_layout` consults before the environment variable.

    NOTE(review): `get_kv_cache_layout` is wrapped in `functools.lru_cache`
    and this function does not clear that cache, so an override installed
    after the first `get_kv_cache_layout()` call presumably has no effect
    until `get_kv_cache_layout.cache_clear()` is called - confirm that this
    is intended.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout


406
407
408
409
@dataclass
class PerLayerParameters:
    """
    Currently, FlashInfer backend only support models in which all layers share
410
411
412
    the same values for the following hyperparameters. Should not be used for
    trtllm-gen backend since it supports different values for the following
    hyperparameters.
413
414
415
    """

    window_left: int
416
    logits_soft_cap: float | None
417
    sm_scale: float
418
    has_sinks: bool = False
419
    # has same params for all layers
420
421
    has_same_window_lefts: bool | None = field(default=None, compare=False)
    has_same_all_params: bool | None = field(default=None, compare=False)
422
423
424


def get_per_layer_parameters(
    vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
    """
    Collect, for every layer in `layer_names`, the hyperparameters that are
    needed during `plan`.

    Each layer's `impl` must be an instance of `cls_`. The result maps the
    layer key to the PerLayerParameters read from that `impl`.
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config, AttentionLayerBase, layer_names
    )

    collected: dict[str, PerLayerParameters] = {}
    for layer_key, attn_layer in attn_layers.items():
        impl = attn_layer.impl
        assert isinstance(impl, cls_)

        # Read the hyperparameters off the attention implementation; a
        # missing or None sliding window is encoded as window_left == -1.
        sliding_window = getattr(impl, "sliding_window", None)
        if sliding_window is None:
            window_left = -1
        else:
            window_left = sliding_window[0]

        collected[layer_key] = PerLayerParameters(
            window_left=window_left,
            logits_soft_cap=getattr(impl, "logits_soft_cap", None),
            sm_scale=impl.scale,
            has_sinks=getattr(impl, "sinks", None) is not None,
        )

    return collected


def infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
    """
    Pick the parameters of the first layer as the global ones and record on
    them whether every layer agrees.

    Currently, FlashInfer backends other than trtllm-gen only support models
    in which all layers share the same values for:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    This function asserts that at least one layer is present, sets
    `has_same_window_lefts` and `has_same_all_params` on the first layer's
    PerLayerParameters (mutating that object in place) and returns it.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    layer_params = list(per_layer_params.values())
    reference = layer_params[0]

    same_window_lefts = True
    for candidate in layer_params:
        if not (candidate.window_left == reference.window_left):
            same_window_lefts = False
            break
    reference.has_same_window_lefts = same_window_lefts

    same_all_params = True
    for candidate in layer_params:
        if not (candidate == reference):
            same_all_params = False
            break
    reference.has_same_all_params = same_all_params

    return reference


483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
#  attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
#  sequences into local ("virtual") batches, where each local batch item is a
#  local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """
    Split every sequence of the batch into local attention blocks and return
    metadata in which each block is an independent "virtual" batch item.

    Args:
        attn_chunk_size: Size, in tokens, of a local attention block.
        common_attn_metadata: Metadata of the real batch.
        block_size: Divisor used to turn token positions into block-table
            columns. Must be positive and divide `attn_chunk_size`; the
            default of 0 is rejected.

    Returns:
        A CommonAttentionMetadata whose `num_reqs` is the number of virtual
        batches. `num_actual_tokens` and `slot_mapping` are carried over
        unchanged and `causal` is True.

    Raises:
        AssertionError: If `block_size` is not positive or does not divide
            `attn_chunk_size`.
    """
    # Validate before block_size is used as a divisor below.
    assert block_size > 0, f"block_size must be positive, got {block_size}"
    assert attn_chunk_size % block_size == 0, (
        f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
    )

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = block_starts[:, None] + np.arange(
        pages_per_local_batch, dtype=np.int32
    )
    # Clip so that the gather below never indexes past the last column of the
    # block table.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)
    block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
        virtual_batches, -1
    )

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # Cast so the field holds a plain Python int, like max_seq_len.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
673
674


675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Rebuild the metadata so that each request's queries are only the tokens
    listed in the (unpadded) logits indices.

    When `max_query_len` is 1 the input is returned unchanged. Otherwise
    `logits_indices_padded` and `num_logits_indices` must both be set.
    """
    if common_attn_metadata.max_query_len == 1:
        # Every request is a decode (1 token is assumed for now), so the
        # fast prefill path does not need to be computed.
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    orig_query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    num_reqs = common_attn_metadata.num_reqs
    # Drop the CUDAGraph padding, if there is any
    unpadded_count = common_attn_metadata.num_logits_indices
    logits_indices = common_attn_metadata.logits_indices_padded[:unpadded_count]
    # Example inputs
    # num_reqs: 3
    # logits_indices:  [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens:        [41, 31, 40]

    # Map every logits index to the request that owns it
    # request_ids: [0, 1, 1, 2]
    request_ends = orig_query_start_loc[1:]
    request_ids = torch.bucketize(logits_indices, request_ends, right=True)

    # Count the logits indices owned by each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # New query_start_loc over just the tokens in logits_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(
        num_reqs + 1,
        device=orig_query_start_loc.device,
        dtype=orig_query_start_loc.dtype,
    )
    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)

    longest_decode_query = int(num_decode_tokens.max().item())
    decode_token_total = int(num_decode_tokens.sum().item())

    return CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=decode_token_total,
        max_query_len=longest_decode_query,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )


735
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Prepended to `attention_backend_cls.__name__` to form
            the name of the new class.
        attention_backend_cls: The backend class to derive from.
        builder_cls: The builder class the new backend should report.

    Returns:
        The dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    # Wrap in staticmethod: a bare zero-argument function stored in the class
    # namespace is bound as a method when looked up on an instance, so
    # `instance.get_builder_cls()` would fail with a TypeError. Lookup through
    # the class keeps working exactly as before.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )
748
749


750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
def split_decodes_prefills_and_extends(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
    """
    Assuming a batch reordered as decode -> extend -> prefill, finds the
    boundaries between the three groups.

    A request is classified from the CPU copies of the metadata:
      * decode:  query length <= decode_threshold
      * prefill: query length > decode_threshold and seq_len == query length
      * extend:  query length > decode_threshold and seq_len != query length

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        num_decodes: The number of decode requests.
        num_extends: The number of extend requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_extend_tokens: The number of tokens in the extend requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
    seq_lens = common_attn_metadata.seq_lens_cpu

    # Fast path: the longest query already qualifies as a decode.
    if max_query_len <= decode_threshold:
        return num_reqs, 0, 0, num_tokens, 0, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill_or_extend = query_lens > decode_threshold

    # This check must come BEFORE argmax: argmax over an all-False mask is 0,
    # which would wrongly report zero decodes instead of "everything is a
    # decode".
    if not torch.any(is_prefill_or_extend):
        return num_reqs, 0, 0, num_tokens, 0, 0

    # Index of the first non-decode request == number of leading decodes.
    first_extend = int(is_prefill_or_extend.int().argmax(dim=-1).item())
    num_decodes = first_extend
    num_decode_tokens = int(query_start_loc[first_extend].item())

    num_prefills_or_extends = num_reqs - num_decodes
    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens

    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
    if not torch.any(is_prefill):
        # Every non-decode request is an extend.
        return (
            num_decodes,
            num_prefills_or_extends,
            0,
            num_decode_tokens,
            num_prefill_or_extend_tokens,
            0,
        )

    first_prefill = int(is_prefill.int().argmax(dim=-1).item())
    num_extends = first_prefill - num_decodes
    num_prefills = num_reqs - first_prefill

    # Convert to a Python int so the declared tuple[int, ...] return type
    # holds (indexing the tensor alone would yield a 0-dim tensor).
    num_prefill_tokens = num_tokens - int(query_start_loc[first_prefill].item())
    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
    return (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    )


817
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    """
    Assuming a reordered batch, finds the boundary between prefill and decode
    requests.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.
        require_uniform: If True, requires that all decode requests have the
            same query length. When set, some queries may be considered prefills
            even if they are <= decode_threshold, in order to ensure uniformity.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    # Fast path: everything is short enough to be a decode. It is skipped when
    # uniformity is required and decode_threshold > 1, because query lengths
    # can then still differ from each other while all being <= the threshold.
    if max_query_len <= decode_threshold and (
        not require_uniform or decode_threshold <= 1
    ):
        return num_reqs, 0, num_tokens, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    if query_lens.numel() == 0:
        # Empty batch: nothing to classify, and query_lens[0] below would
        # raise IndexError.
        return num_reqs, 0, num_tokens, 0

    if query_lens[0].item() > decode_threshold:
        # first request is not decode, so no decode requests
        return 0, num_reqs, 0, num_tokens

    if require_uniform:
        # Any request whose length differs from the first one is treated as a
        # prefill, so the decode group has a single query length.
        is_prefill = query_lens != query_lens[0]
    else:
        is_prefill = query_lens > decode_threshold

    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0

    # Index of the first prefill == number of leading decode requests.
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = query_start_loc[first_prefill].item()
    num_prefill_tokens = num_tokens - num_decode_tokens
    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    The batch is rearranged in place (via `input_batch.swap_states`) into
    decode -> extend -> prefill order.

    Args:
        input_batch: The batch to reorder. Its `req_ids`,
            `num_computed_tokens_cpu` and `swap_states` are used.
        scheduler_output: Provides `num_scheduled_tokens`, indexed by
            request id.
        decode_threshold: The maximum number of scheduled tokens for a
            request to be considered a decode.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We now want to reorder the batch into decode → extend → prefill order
    # where:
    #   decode: request with num_scheduled_tokens <= decode_threshold
    #   extend: non-decode request with existing context
    #   prefill: non-decode request with no existing context
    # NOTE for now we loosely use "decode" to mean requests where attention is
    #  likely memory-bound and "prefill" to mean requests where attention is
    #  likely compute-bound,
    num_reqs = len(input_batch.req_ids)
    num_scheduled_tokens_np = np.array(
        [
            scheduler_output.num_scheduled_tokens[req_id]
            for req_id in input_batch.req_ids
        ]
    )
    num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]

    is_decode = num_scheduled_tokens_np <= decode_threshold
    is_extend = (~is_decode) & (num_computed_tokens_np > 0)
    is_prefill = (~is_decode) & (num_computed_tokens_np == 0)

    # Desired order: decode → extend → prefill
    # Region id currently held by each slot (0 = decode by default).
    req_regions = np.zeros(is_decode.shape, dtype=np.int32)
    req_regions[is_extend] = 1
    req_regions[is_prefill] = 2

    num_decodes = int(is_decode.sum())
    num_extends = int(is_extend.sum())

    # Region id each slot should hold once the batch is ordered.
    target_regions = np.zeros(num_reqs, dtype=np.int32)
    target_regions[num_decodes : num_decodes + num_extends] = 1
    target_regions[num_decodes + num_extends :] = 2

    needs_swap = req_regions != target_regions
    if not needs_swap.any():
        return False

    # Only misplaced slots take part. Sorting them (stably) by their current
    # region and pairing the result with the misplaced slots in ascending
    # order yields, for every source slot, the slot it has to move to.
    orig_indices = np.where(needs_swap)[0]
    sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
    src_indices = orig_indices[sorted_order]
    src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}

    # Resolve the permutation cycle by cycle using pairwise swaps. Only the
    # values of existing keys are rewritten inside the loop, so the dict does
    # not change size while it is being iterated.
    for src in src_dest_map:
        dst = src_dest_map[src]
        while src != dst:
            input_batch.swap_states(src, dst)
            # Mark dst as done by updating its destination to itself
            next_dst = src_dest_map.get(dst, dst)
            src_dest_map[dst] = dst
            dst = next_dst

    return True
937
938


939
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Reshapes the query tensor for the specified batch size, so that
    it has shape (batch_size, seq_len, num_heads, head_dim).

    Args:
        query: A 3D tensor of shape (total_tokens, num_heads, head_dim).
        batch_size: Number of requests; must evenly divide total_tokens.

    Returns:
        A view of `query` with the leading dimension split into
        (batch_size, total_tokens // batch_size).
    """
    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
    total_tokens, num_heads, head_dim = query.shape
    assert total_tokens % batch_size == 0, (
        f"{total_tokens=} is not divisible by {batch_size=}"
    )
    seq_len = total_tokens // batch_size
    return query.view(batch_size, seq_len, num_heads, head_dim)


955
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
    """
    Reshapes the attention output tensor, so that
    the batch_size and seq_len dimensions are combined.

    Args:
        attn_output: Either a 3D tensor, which is returned unchanged, or a
            4D tensor whose first two dimensions are merged.

    Returns:
        A 3D tensor; for 4D input this is a view with shape
        (dim0 * dim1, dim2, dim3).
    """
    if attn_output.dim() == 3:
        # Already in the correct shape
        return attn_output
    assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
    total_tokens = attn_output.shape[0] * attn_output.shape[1]
    return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
966
967


968
969
970
971
972
973
974
975
976
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields

    Args:
        name_prefix: String prepended to `metadata_cls.__name__` to form the
            name of the generated class.
        metadata_cls: The class to derive from.
        fields: Field specifications forwarded unchanged to
            `dataclasses.make_dataclass`.

    Returns:
        The dataclass created by `make_dataclass` with `metadata_cls` as its
        only base.
    """
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    return make_dataclass(name, fields, bases=(metadata_cls,))


981
982
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
983
984
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None
985
986
987
988
989
990
991
992
993


def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
    """
    Build a backend subclass whose metadata builder produces KV-sharing
    fast-prefill metadata.

    The returned backend behaves like `underlying_attn_backend`, except that
    its builder first rewrites the common attention metadata with
    `make_kv_sharing_fast_prefill_common_attn_metadata` and then attaches
    `logits_indices_padded` / `num_logits_indices` to the built metadata.

    Args:
        prefix: Name prefix for the generated backend class.
        underlying_attn_backend: The backend class to wrap; it is used as the
            base class of the generated backend.

    Returns:
        The generated backend class.
    """
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Let the underlying builder work on the fast-prefill variant of
            # the common metadata.
            new_common_attn_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            )
            metadata = super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

            # Subclass the concrete metadata type so the result is still an
            # instance of it, while also exposing the protocol attributes.
            class KVSharingFastPrefillAttentionMetadata(
                metadata.__class__,  #  type: ignore
                KVSharingFastPrefillMetadata,
            ):
                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all fields in metadata cls
                    for _field in fields(metadata.__class__):
                        setattr(self, _field.name, getattr(metadata, _field.name))

                    # The logits indices come from the ORIGINAL common
                    # metadata passed to build(), not the rewritten one.
                    self.logits_indices_padded = (
                        common_attn_metadata.logits_indices_padded
                    )
                    self.num_logits_indices = common_attn_metadata.num_logits_indices

            return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder,
    )

    return attn_backend
1030
1031
1032
1033


def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """
    Precompute the per-BLOCK_M lookup tables needed for causal_conv1d.

    For every BLOCK_M value, each sequence is split into
    ceil(seq_len / BLOCK_M) chunks. Two flat tables describe those chunks:
    which sequence a chunk belongs to, and the chunk's index within that
    sequence.

    Args:
        query_start_loc_p: Start offsets of the sequences; consecutive
            differences are taken as the sequence lengths.
            NOTE(review): presumably 1-D with num_seqs + 1 entries - confirm
            against callers.

    Returns:
        A tuple `(nums_dict, batch_ptr, token_chunk_offset_ptr)` where
        `nums_dict[BLOCK_M]` holds the keys "nums", "tot", "mlist",
        "mlist_len", "offsetlist", "batch_ptr" and "token_chunk_offset_ptr",
        and the two pointer tensors are int32 buffers on the device of
        `query_start_loc_p`, padded with PAD_SLOT_ID.
    """
    # Needed for causal_conv1d
    seqlens = query_start_loc_p.diff().to("cpu")
    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None
    device = query_start_loc_p.device
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Ceil division: number of BLOCK_M-sized chunks per sequence.
        nums = -(-seqlens // BLOCK_M)
        nums_dict[BLOCK_M] = {}
        nums_dict[BLOCK_M]["nums"] = nums
        nums_dict[BLOCK_M]["tot"] = nums.sum().item()

        # Sequence index of every chunk, e.g. nums=[2, 1] -> [0, 0, 1].
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        nums_dict[BLOCK_M]["mlist"] = mlist
        mlist_len = len(mlist)
        nums_dict[BLOCK_M]["mlist_len"] = mlist_len

        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2

        # Index of every chunk within its own sequence, e.g. nums=[2, 1] ->
        # [0, 1, 0]. tolist() yields plain ints instead of 0-dim tensors.
        offsetlist = torch.tensor(
            [offset for num in nums.tolist() for offset in range(num)],
            dtype=torch.int32,
        )
        nums_dict[BLOCK_M]["offsetlist"] = offsetlist

        if batch_ptr is None:
            # First BLOCK_M: allocate the padded buffers.
            batch_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
            token_chunk_offset_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
        else:
            # Later BLOCK_M values reuse the buffers, growing them if needed.
            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
                token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS
                ).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[  # type: ignore
            0:mlist_len
        ].copy_(offsetlist)
        nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
        nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

    return nums_dict, batch_ptr, token_chunk_offset_ptr
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115


def get_dcp_local_seq_lens(
    seq_lens: torch.Tensor,
    dcp_world_size: int = 1,
    dcp_rank: int | None = None,
    dcp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
    """While using dcp, kv_cache size stored on each rank may be different,
    use this function to calculate split decode seq_lens of each dcp rank.
    Only consider dcp now, we can extend the case of cp based on this.
    """
    num_requests = seq_lens.size(0)
    if dcp_rank is None:
        rank_offsets = (
            torch.arange(dcp_world_size, dtype=torch.int32)
            .unsqueeze(0)
            .repeat(num_requests, 1)
        )
    else:
        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
    seq_lens_tiled = (
        seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
    )
    base = (
        seq_lens_tiled
        // dcp_kv_cache_interleave_size
        // dcp_world_size
        * dcp_kv_cache_interleave_size
    )
    remainder = seq_lens_tiled - base * dcp_world_size
    remainder = torch.clip(
        remainder - rank_offsets * dcp_kv_cache_interleave_size,
        0,
        dcp_kv_cache_interleave_size,
    )
    dcp_local_seq_lens = base + remainder
    return dcp_local_seq_lens.squeeze(1)