utils.py 36.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import abc
4
import enum
5
import functools
6
from abc import abstractmethod
7
from dataclasses import dataclass, field, fields, make_dataclass
8
9
10
11
12
13
14
15
16
17
18
19
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Generic,
    Literal,
    Optional,
    Protocol,
    TypeVar,
    Union,
    get_args,
)
20

21
import numpy as np
22
import torch
23
from typing_extensions import runtime_checkable
24

25
from vllm.config import VllmConfig, get_layers_from_vllm_config
26
27
from vllm.utils import cdiv

28
if TYPE_CHECKING:
29
    from vllm.attention.backends.abstract import AttentionImpl
30
31
32
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

33
import vllm.envs as envs
34
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
35
from vllm.distributed.kv_transfer.kv_connector.utils import (
36
37
    get_kv_connector_cache_layout,
)
38
from vllm.logger import init_logger
39
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
40
from vllm.v1.kv_cache_interface import AttentionSpec
41
from vllm.v1.worker.ubatch_utils import UBatchSlice
42
43

# Module-level logger.
logger = init_logger(__name__)

# The KV cache layout names accepted by this module.
KVCacheLayoutType = Literal["NHD", "HND"]
# Programmatic layout override written by `set_kv_cache_layout`; when it is
# not None, `get_kv_cache_layout` returns it instead of consulting the
# environment variable or the KV connector default.
_KV_CACHE_LAYOUT_OVERRIDE: Optional[KVCacheLayoutType] = None

# NOTE(review): presumably the slot-mapping value used for padded positions;
# confirm against the users of this constant.
PAD_SLOT_ID = -1

49
50
51

def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True if `value` is one of the literals of `KVCacheLayoutType`."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
52

53
54
55
56

@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length in batch"""

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None

    dcp_local_seq_lens: Optional[torch.Tensor] = None
    """Sequence lengths of the local rank in decode context parallelism world"""

99

100
101
102
103
104
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Creates a new query_start_loc that corresponds to the requests in
    request_slice.

    The returned offsets are rebased so that the first selected request
    starts at 0.

    Note: This function creates a new tensor to hold the new query_start_locs.
    This will break cudagraph compatibility.

    Args:
        query_start_loc: Cumulative query start offsets, one more entry than
            there are requests.
        request_slice: The requests to keep; `start` and `stop` must both be
            set.

    Returns:
        A new tensor with `stop - start + 1` entries; the input is not
        modified.
    """
    first = request_slice.start
    # One extra entry so the end offset of the last request is included.
    selected = query_start_loc[first : request_slice.stop + 1]
    return selected - query_start_loc[first]
115
116
117


def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice

    Args:
        ubatch_slice: The request and token ranges of the micro-batch. Must
            not be empty, and its token range must start inside its first
            request and end inside its last request.
        attn_metadata: The metadata of the full batch; it is not modified.

    Returns:
        Metadata restricted to the micro-batch. When the token range cuts a
        request in two, the query lengths (and, for a truncated last request,
        the sequence length) are adjusted to cover only the tokens that
        belong to this micro-batch.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], (
        "Token slice end outside of last request"
    )

    # If the "middle" request has tokens in both ubatches, we have to split it.
    # If ubatch_slice is the first ubatch then we will be splitting the last
    # request. If it's the second microbatch, then we will be splitting the
    # first request
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    # slice_query_start_locs returns fresh tensors, so the in-place edits below
    # never touch attn_metadata (start_locs stays the original offsets).
    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )

    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if splits_first_request:
        # Leading tokens of the first request that belong to an earlier ubatch.
        tokens_skipped = int(first_tok - start_locs[first_req])
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

    if splits_last_request:
        # Trailing tokens of the last request that belong to a later ubatch.
        # This is computed from the original (absolute) start_locs: the
        # rebased query_start_loc_cpu is relative to this ubatch and cannot be
        # compared against the absolute token_slice.stop.
        tokens_skipped = int(start_locs[last_req + 1] - token_slice.stop)
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        #  (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )


def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Creates a new CommonAttentionMetadata instance that corresponds to the
    requests for each UBatchSlice in ubatch_slices.

    Note: This function does not modify common_attn_metadata

    Args:
        ubatch_slices: The micro-batch slices to split the batch into.
        common_attn_metadata: The metadata of the full batch.

    Returns:
        One CommonAttentionMetadata per entry of ubatch_slices, in the same
        order.
    """
    return [
        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
        for ubatch_slice in ubatch_slices
    ]


227
228
229
# Type of the metadata object returned by an AttentionMetadataBuilder's
# `build` methods.
M = TypeVar("M")


230
class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend
    Here we do not consider the cascade attention, as currently
    it is never cudagraph supported."""

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that
    are the same, this can be used for spec-decode
        i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1 decodes"""
    NEVER = 0
    """NO cudagraph support"""


247
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Abstract base class for builders that turn a CommonAttentionMetadata into
    metadata of type `M`.

    Subclasses must implement `__init__` and `build`; the remaining methods
    have defaults that delegate to `build` or report "not supported".
    """

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: Optional[int] = None

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """
        Store the constructor arguments on the instance.

        Declared abstract, so subclasses must define their own `__init__`;
        they can call `super().__init__(...)` to get these attributes set.
        """
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    def _init_reorder_batch_threshold(
        self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
    ) -> None:
        """
        Set `self.reorder_batch_threshold`.

        Args:
            reorder_batch_threshold: The base threshold to store.
            supports_spec_as_decode: When True and the config specifies a
                number of speculative tokens, the threshold is raised to at
                least `1 + num_speculative_tokens`.
        """
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
            # If the backend supports spec-as-decode kernels, then we can set
            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
            speculative_config = self.vllm_config.speculative_config
            if (
                speculative_config is not None
                and speculative_config.num_speculative_tokens is not None
            ):
                self.reorder_batch_threshold = max(
                    self.reorder_batch_threshold,
                    1 + speculative_config.num_speculative_tokens,
                )

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
                The default implementation does not use it.
        """
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
    ) -> bool:
        """
        Report whether cascade attention should be used for this batch.

        The default implementation ignores its arguments and returns False.
        """
        return False

353

354
355
@functools.lru_cache
def get_kv_cache_layout():
    """
    Return the KV cache layout to use.

    Sources are consulted in this order:
      1. the `_KV_CACHE_LAYOUT_OVERRIDE` module variable (format specified by
         the code),
      2. the `VLLM_KV_CACHE_LAYOUT` environment variable (format specified by
         the user),
      3. `get_kv_connector_cache_layout()` as the default.

    The result is memoized by `functools.lru_cache`, so every call after the
    first returns the first result until the cache is cleared.
    """
    # Format specified by the code. Only read here, so no `global` is needed.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.",
            cache_layout,
        )
        return cache_layout

    # Format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    # When neither the user nor the override specified a layout, get default
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        assert is_valid_kv_cache_layout(cache_layout), (
            f"Invalid VLLM_KV_CACHE_LAYOUT value: {cache_layout!r}"
        )
        logger.info_once(
            "`VLLM_KV_CACHE_LAYOUT` environment variable "
            "detected. Setting KV cache layout to %s.",
            cache_layout,
        )
    return cache_layout
381
382


383
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """
    Set the module-level KV cache layout override.

    NOTE: `get_kv_cache_layout` is wrapped in `functools.lru_cache`, so an
    override set after its first call is not observed unless that cache is
    cleared (`get_kv_cache_layout.cache_clear()`).
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout


388
389
390
391
@dataclass
class PerLayerParameters:
    """
    Currently, FlashInfer backend only support models in which all layers share
    the same values for the following hyperparameters. Should not be used for
    trtllm-gen backend since it supports different values for the following
    hyperparameters.
    """

    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float
    has_sinks: bool = False
    # has same params for all layers
    # (compare=False: these two flags are excluded from dataclass equality)
    has_same_window_lefts: Optional[bool] = field(default=None, compare=False)
    has_same_all_params: Optional[bool] = field(default=None, compare=False)
404
405
406


def get_per_layer_parameters(
    vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
    """
    Scan layers in `layer_names` and determine some hyperparameters
    to use during `plan`.

    Args:
        vllm_config: Config from which the attention layers are looked up.
        layer_names: Names of the layers to inspect.
        cls_: The implementation class every layer's `impl` must be an
            instance of (checked with an assert).

    Returns:
        A mapping from layer key to the PerLayerParameters read from that
        layer's implementation.
    """
    layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
    per_layer_params: dict[str, PerLayerParameters] = {}

    for key, layer in layers.items():
        impl = layer.impl
        assert isinstance(impl, cls_)

        # Infer hyperparameters from the attention layer
        window_size = getattr(impl, "sliding_window", None)
        # -1 stands for "no sliding window attribute / it is None".
        window_left = window_size[0] if window_size is not None else -1
        logits_soft_cap = getattr(impl, "logits_soft_cap", None)
        sm_scale = impl.scale
        has_sinks = getattr(impl, "sinks", None) is not None

        per_layer_params[key] = PerLayerParameters(
            window_left, logits_soft_cap, sm_scale, has_sinks
        )

    return per_layer_params


def infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
    """
    Currently, FlashInfer backend other than trtllm-gen
    only support models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    This function does not itself enforce that. It takes the first layer's
    parameters as the global values and records on them whether every layer
    agrees:
    - `has_same_window_lefts`: all layers have the same `window_left`
    - `has_same_all_params`: all layers compare equal to the first layer

    Callers are expected to check these flags.

    Note: the returned object is the first entry of `per_layer_params`; its
    two flags are set in place.

    Raises:
        AssertionError: if `per_layer_params` is empty.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]

    global_params.has_same_window_lefts = all(
        params.window_left == global_params.window_left for params in param_sets
    )
    global_params.has_same_all_params = all(
        params == global_params for params in param_sets
    )

    return global_params


465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
#  attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
#  sequences into local ("virtual") batches, where each local batch item is a
#  local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """
    Break each request into local attention blocks of `attn_chunk_size` tokens
    and return metadata in which every block is its own ("virtual") request.

    Args:
        attn_chunk_size: Size of a local attention block, in tokens. Must be a
            multiple of `block_size`.
        common_attn_metadata: Metadata of the real batch; it is not modified.
        block_size: KV cache block (page) size, in tokens. Must be positive;
            the default of 0 is rejected.

    Returns:
        A new CommonAttentionMetadata with one entry per virtual batch.
        `num_actual_tokens` and `slot_mapping` are carried over unchanged.

    Raises:
        AssertionError: if `block_size` is not positive or does not divide
            `attn_chunk_size`.
    """
    # Validate up front: block_size is used as a divisor below.
    assert block_size > 0, f"block_size must be positive, got {block_size}"
    assert attn_chunk_size % block_size == 0, (
        f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
    )

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = block_starts[:, None] + np.arange(
        pages_per_local_batch, dtype=np.int32
    )
    # Clip so indices past the end of the block table stay in range.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)
    block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
        virtual_batches, -1
    )

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # Convert the numpy scalar so the field really holds a Python int.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
655
656


657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Build metadata whose queries contain only the tokens listed in
    `logits_indices_padded[:num_logits_indices]`.

    Args:
        common_attn_metadata: Metadata of the full batch. Unless
            `max_query_len == 1`, `logits_indices_padded` and
            `num_logits_indices` must be set.

    Returns:
        The input object itself when `max_query_len == 1`. Otherwise a new
        CommonAttentionMetadata with recomputed query start locations,
        `num_actual_tokens` and `max_query_len`; sequence lengths, block table
        and slot mapping are carried over.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    logits_indices_padded = common_attn_metadata.logits_indices_padded
    num_logits_indices = common_attn_metadata.num_logits_indices
    # Get rid of CUDAGraph padding, if any
    logits_indices = logits_indices_padded[:num_logits_indices]
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    # Example inputs
    # num_reqs: 3
    # logits_indices:  [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens:        [41, 31, 40]

    # Find which request each logits index belongs to
    # request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)

    # Figure out how many tokens are in each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # Calculate new query_start_loc with tokens in logits_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(
        num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype
    )

    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
    decode_max_query_len = int(num_decode_tokens.max().item())
    total_num_decode_tokens = int(num_decode_tokens.sum().item())

    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_decode_tokens,
        max_query_len=decode_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
    return common_attn_metadata


717
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Prepended to `attention_backend_cls.__name__` to form the
            name of the new class.
        attention_backend_cls: The backend class to derive from.
        builder_cls: The builder class the new backend should report.

    Returns:
        A dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    def get_builder_cls() -> type[AttentionMetadataBuilder[M]]:
        return builder_cls

    # Wrapped in staticmethod so the zero-argument function also works when
    # looked up on an instance; a bare function would be bound and receive
    # the instance as an unexpected positional argument.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(get_builder_cls)},
    )
730
731


732
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    """
    Assuming a reordered batch, finds the boundary between prefill and decode
    requests.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.
        require_uniform: If True, requires that all decode requests have the
            same query length. When set, some queries may be considered prefills
            even if they are <= decode_threshold, in order to ensure uniformity.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    # Fast path: every request is short enough to be a decode. With
    # require_uniform this is only safe when the threshold allows a single
    # query length (<= 1).
    if max_query_len <= decode_threshold and (
        not require_uniform or decode_threshold <= 1
    ):
        return num_reqs, 0, num_tokens, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    if query_lens[0].item() > decode_threshold:
        # first request is not decode, so no decode requests
        return 0, num_reqs, 0, num_tokens

    if require_uniform:
        # Anything whose length differs from the first request's is treated
        # as a prefill.
        is_prefill = query_lens != query_lens[0]
    else:
        is_prefill = query_lens > decode_threshold

    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0

    # argmax over the 0/1 mask gives the index of the first prefill.
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = query_start_loc[first_prefill].item()
    num_prefill_tokens = num_tokens - num_decode_tokens
    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch in place so that every request with at most
    `decode_threshold` scheduled tokens sits in front of every request with
    more than `decode_threshold` scheduled tokens.

    Args:
        input_batch: Batch to reorder. Its `req_ids` are read and its
            `swap_states(i, j)` method is called to exchange two positions.
        scheduler_output: Provides `num_scheduled_tokens`, indexed by
            request id.
        decode_threshold: Largest scheduled-token count still treated as a
            "decode" request.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # NOTE: "decode" is used loosely for requests where attention is likely
    # memory-bound and "prefill" for requests where attention is likely
    # compute-bound. TODO(lucas): figure out a better naming here.
    scheduled = scheduler_output.num_scheduled_tokens

    # Both position lists come out in ascending order because the batch is
    # scanned front to back.
    decode_positions: list[int] = []
    prefill_positions: list[int] = []
    for pos, req_id in enumerate(input_batch.req_ids):
        if scheduled[req_id] <= decode_threshold:
            decode_positions.append(pos)
        else:
            prefill_positions.append(pos)

    # Decodes tend to stay in the batch for many iterations and new requests
    # are generally appended at the back, so the number of swaps needed should
    # be small. To keep it minimal, walk the decodes from the back of the
    # batch and the prefills from the front: a decode that lies past the slot
    # where the last decode belongs in the reordered batch is swapped with the
    # front-most prefill that has not been moved yet.
    boundary = len(decode_positions)
    swapped = False
    for decode_pos, prefill_pos in zip(reversed(decode_positions), prefill_positions):
        if decode_pos < boundary:
            # This decode (and therefore every earlier one) is already
            # inside the decode region; nothing left to move.
            break
        input_batch.swap_states(prefill_pos, decode_pos)
        swapped = True

    return swapped
844
845


846
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Reshapes the query tensor for the specified batch size, so that
    it has shape (batch_size, seq_len, num_heads, head_dim).

    Args:
        query: 3D tensor of shape (total_tokens, num_heads, head_dim).
        batch_size: Number of requests; must be positive and divide
            total_tokens evenly.

    Returns:
        A view of `query` with shape
        (batch_size, total_tokens // batch_size, num_heads, head_dim).

    Raises:
        AssertionError: If `query` is not 3D, `batch_size` is not positive,
            or total_tokens is not divisible by `batch_size`.
    """
    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
    # Reject zero/negative batch sizes up front so the failure is a clear
    # assertion rather than a ZeroDivisionError from the modulo below.
    assert batch_size > 0, f"{batch_size=} must be positive"
    total_tokens, num_heads, head_dim = query.shape
    assert total_tokens % batch_size == 0, (
        f"{total_tokens=} is not divisible by {batch_size=}"
    )
    seq_len = total_tokens // batch_size
    return query.view(batch_size, seq_len, num_heads, head_dim)


862
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
    """
    Reshapes the attention output tensor, so that
    the batch_size and seq_len dimensions are combined.

    Args:
        attn_output: Either a 3D tensor, which is returned unchanged, or a 4D
            tensor whose first two dimensions are merged.

    Returns:
        A 3D tensor. For 4D input of shape (d0, d1, d2, d3) this is a view of
        shape (d0 * d1, d2, d3).

    Raises:
        AssertionError: If `attn_output` is neither 3D nor 4D.
    """
    if attn_output.dim() == 3:
        # Already in the correct shape
        return attn_output
    assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
    total_tokens = attn_output.shape[0] * attn_output.shape[1]
    return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
873
874


875
# Extra dataclass fields, as (name, type, default) tuples, describing the
# additional attributes used for KV-sharing fast prefill. The tuple layout
# matches the `fields` argument taken by `subclass_attention_metadata`.
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ("logits_indices_padded", Optional[torch.Tensor], None),
    ("num_logits_indices", int, 0),
]


def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields

    Args:
        name_prefix: String prepended to `metadata_cls.__name__` to form the
            name of the new class.
        metadata_cls: Class to derive from; it becomes the sole base of the
            generated dataclass.
        fields: Extra fields as (name, type, default) tuples, forwarded
            unchanged to `dataclasses.make_dataclass`.

    Returns:
        The newly created dataclass type.
    """
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    Wrapped = make_dataclass(name, fields, bases=(metadata_cls,))
    return Wrapped


894
895
896
897
898
899
900
901
902
903
904
905
906
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """
    Structural type for attention metadata that carries the extra
    KV-sharing fast-prefill attributes.

    The protocol is decorated with `runtime_checkable`, so it can be used
    with `isinstance()`. The metadata class generated inside
    `create_fast_prefill_custom_backend` lists this protocol as a base.
    """

    # NOTE(review): presumably the padded indices of the tokens whose logits
    # are needed - confirm against the code that populates this field.
    logits_indices_padded: torch.Tensor
    # NOTE(review): presumably the count of valid (non-padding) entries in
    # `logits_indices_padded` - confirm against the producer.
    num_logits_indices: int


def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """
    Derive an attention backend whose metadata builder produces
    KV-sharing fast-prefill metadata.

    The returned backend is created with `subclass_attention_backend` and
    uses a builder that subclasses the underlying backend's builder. That
    builder first converts the incoming common attention metadata with
    `make_kv_sharing_fast_prefill_common_attn_metadata`, delegates to the
    parent `build`, and then wraps the result in a metadata object that
    additionally exposes `logits_indices_padded` and `num_logits_indices`.

    Args:
        prefix: Passed as `name_prefix` to `subclass_attention_backend`.
        underlying_attn_backend: Backend to derive from; its
            `get_builder_cls()` supplies the base builder class.

    Returns:
        The derived attention backend class.
    """
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            """
            Build metadata via the parent builder using the fast-prefill
            variant of `common_attn_metadata`, then attach the logits-index
            fields taken from the original `common_attn_metadata`.
            """
            new_common_attn_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            )
            metadata = super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

            # The concrete metadata type is only known once the parent
            # builder has run, so the wrapper class is derived from it here.
            class KVSharingFastPrefillAttentionMetadata(
                metadata.__class__,  #  type: ignore
                KVSharingFastPrefillMetadata,
            ):
                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all fields in metadata cls
                    for _field in fields(metadata.__class__):
                        setattr(self, _field.name, getattr(metadata, _field.name))

                    # Set additional fields that will be used in model code
                    assert (
                        common_attn_metadata.logits_indices_padded is not None
                        and common_attn_metadata.num_logits_indices is not None
                    )
                    self.logits_indices_padded = (
                        common_attn_metadata.logits_indices_padded
                    )
                    self.num_logits_indices = common_attn_metadata.num_logits_indices

            return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder,
    )

    return attn_backend
948
949
950
951


def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """
    Precompute the chunk lookup tables needed for causal_conv1d.

    Sequence lengths are taken as the consecutive differences of
    `query_start_loc_p`. For each BLOCK_M value, every sequence is split into
    ceil(seq_len / BLOCK_M) chunks and two flat tables with one entry per
    chunk are produced: the index of the sequence the chunk belongs to, and
    the index of the chunk within that sequence.

    Args:
        query_start_loc_p: 1D tensor of cumulative start offsets. Its device
            is used for the returned padded buffers.

    Returns:
        A tuple `(nums_dict, batch_ptr, token_chunk_offset_ptr)` where
        `nums_dict[BLOCK_M]` holds:
            - "nums": chunks per sequence (CPU tensor),
            - "tot": total number of chunks,
            - "mlist": sequence index of each chunk (CPU tensor),
            - "mlist_len": number of entries in "mlist",
            - "offsetlist": per-sequence chunk index of each chunk
              (CPU int32 tensor),
            - "batch_ptr" / "token_chunk_offset_ptr": the buffers below.
        `batch_ptr` and `token_chunk_offset_ptr` are int32 tensors on the
        input's device with at least `max(1024, mlist_len) * 2` elements; the
        first `mlist_len` entries hold "mlist" / "offsetlist" and the rest are
        `PAD_SLOT_ID`.
    """
    # Needed for causal_conv1d
    # The lengths are moved to the CPU because they are consumed below as
    # NumPy repeat counts and Python ints.
    seqlens = query_start_loc_p.diff().to("cpu")
    device = query_start_loc_p.device
    nums_dict: dict[int, dict[str, Any]] = {}
    batch_ptr = None
    token_chunk_offset_ptr = None
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Ceiling division: number of BLOCK_M-sized chunks per sequence.
        nums = -(-seqlens // BLOCK_M)

        # Sequence index repeated once per chunk of that sequence.
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        mlist_len = len(mlist)

        # Chunk index within its own sequence: 0..num-1 for every sequence.
        offsets: list[int] = []
        for num in nums.tolist():
            offsets.extend(range(num))
        offsetlist = torch.tensor(offsets, dtype=torch.int32)

        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2

        if batch_ptr is None:
            # First BLOCK_M value: allocate the padded buffers.
            batch_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
            token_chunk_offset_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
        else:
            # Later BLOCK_M values reuse the buffers, growing them if needed.
            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
                token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS
                ).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[  # type: ignore
            0:mlist_len
        ].copy_(offsetlist)

        nums_dict[BLOCK_M] = {
            "nums": nums,
            "tot": nums.sum().item(),
            "mlist": mlist,
            "mlist_len": mlist_len,
            "offsetlist": offsetlist,
            "batch_ptr": batch_ptr,
            "token_chunk_offset_ptr": token_chunk_offset_ptr,  # type: ignore
        }

    return nums_dict, batch_ptr, token_chunk_offset_ptr