utils.py 47 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import abc
import enum
import functools
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass, replace
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Generic,
    Literal,
    Protocol,
    TypeVar,
    get_args,
)

import numpy as np
import torch
from typing_extensions import deprecated, runtime_checkable

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm.envs as envs
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
44
45

logger = init_logger(__name__)

# The KV cache layout names accepted by this module.
KVCacheLayoutType = Literal["NHD", "HND"]

# Layout forced by set_kv_cache_layout(); when not None, get_kv_cache_layout()
# returns it in preference to the VLLM_KV_CACHE_LAYOUT environment variable.
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None

# NOTE(review): presumably the slot-mapping value used for padded tokens
# (not referenced in this part of the file) — confirm with its users.
PAD_SLOT_ID = -1

51
52
53

def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True if `value` is one of the names listed in KVCacheLayoutType."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
54

55
56
57
58

@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    """(batch_size,), the length of each request: the number of already
    computed tokens plus the number of query tokens scheduled in this batch
    (see compute_num_computed_tokens, which relies on this)"""

    num_reqs: int
    """Number of requests"""
    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length (may be an upper bound)"""

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: torch.Tensor | None = None
    encoder_seq_lens_cpu: np.ndarray | None = None

    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None
    """Sequence lengths of the local rank in decode context parallelism world"""

    # WARNING: Deprecated fields. Will be removed in a future release (v0.14.0)
    _seq_lens_cpu: torch.Tensor | None = None
    _num_computed_tokens_cpu: torch.Tensor | None = None

    # Lazily filled by compute_num_computed_tokens().
    _num_computed_tokens_cache: torch.Tensor | None = None

    def _get_seq_lens_cpu(self) -> torch.Tensor:
        """Return (and cache) a CPU copy of `seq_lens`.

        Shared by the deprecated properties below so that accessing one of
        them does not also trigger the other one's deprecation warning.
        """
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu

    @property
    @deprecated(
        """
    Prefer using device seq_lens directly to avoid implicit H<>D sync.
    If a CPU copy is needed, use `seq_lens.cpu()` instead.
    Will be removed in a future release (v0.14.0)
    """
    )
    def seq_lens_cpu(self) -> torch.Tensor:
        return self._get_seq_lens_cpu()

    @property
    @deprecated(
        """
    Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
    async scheduling. If a CPU copy is needed, it can be derived from 
    query_start_loc_cpu and seq_lens.
    Will be removed in a future release (v0.14.0)
    """
    )
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        if self._num_computed_tokens_cpu is None:
            query_seq_lens = (
                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
            )
            # Use the private accessor rather than the deprecated
            # `seq_lens_cpu` property to avoid a second, nested warning.
            self._num_computed_tokens_cpu = self._get_seq_lens_cpu() - query_seq_lens
        return self._num_computed_tokens_cpu

    def compute_num_computed_tokens(self) -> torch.Tensor:
        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
        if self._num_computed_tokens_cache is None:
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            self._num_computed_tokens_cache = self.seq_lens - query_lens
        return self._num_computed_tokens_cache

    # TODO(lucas): remove once we have FULL-CG spec-decode support
    def unpadded(
        self, num_actual_tokens: int, num_actual_reqs: int
    ) -> "CommonAttentionMetadata":
        """
        Return a copy restricted to the first `num_actual_reqs` requests and
        the first `num_actual_tokens` tokens.

        max_query_len, max_seq_len, causal and the logits-index fields are
        carried over unchanged, so the max_* values may be upper bounds for
        the returned metadata.
        """

        def slice_reqs(x):
            # Optional per-request fields: keep None as None.
            return x[:num_actual_reqs] if x is not None else None

        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            _seq_lens_cpu=slice_reqs(self._seq_lens_cpu),
            _num_computed_tokens_cpu=slice_reqs(self._num_computed_tokens_cpu),
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
            max_query_len=self.max_query_len,
            max_seq_len=self.max_seq_len,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            logits_indices_padded=self.logits_indices_padded,
            num_logits_indices=self.num_logits_indices,
            encoder_seq_lens=slice_reqs(self.encoder_seq_lens),
            encoder_seq_lens_cpu=slice_reqs(self.encoder_seq_lens_cpu),
            dcp_local_seq_lens=slice_reqs(self.dcp_local_seq_lens),
            dcp_local_seq_lens_cpu=slice_reqs(self.dcp_local_seq_lens_cpu),
        )

172

173
174
175
176
177
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for just the requests in `request_slice`,
    rebased so that its first entry is 0.

    Note: the subtraction allocates a fresh tensor, so the input is never
    modified, but the result is a new allocation, which breaks cudagraph
    compatibility.
    """
    lo = request_slice.start
    hi = request_slice.stop
    window = query_start_loc[lo : hi + 1]
    offset = query_start_loc[lo]
    return window - offset
188
189
190


def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice.

    A request may straddle a ubatch boundary. Its tokens outside
    `ubatch_slice.token_slice` are removed from the query lengths, and
    seq_lens / num_computed_tokens are adjusted so that, for every split
    request, `num_computed_tokens == seq_lens - query_len` still holds.

    The tensors held by `attn_metadata` are never written to: adjusted values
    go into fresh tensors or clones.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    # NOTE: last token can be outside of the last request if we have CG padding.

    # If the request is split across ubatches, we have to adjust the metadata.
    # splits_first_request: The first request in this slice is the continuation of
    #                       a request that started in a previous slice.
    # splits_last_request:  The last request in this slice continues into the
    #                       next slice.
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )

    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    if splits_first_request:
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped
        # The leading tokens of the first request were scheduled in a previous
        # ubatch, so for this slice they are already-computed context:
        # seq_lens is unchanged while the query shrank, hence
        # num_computed_tokens must grow by the same amount. Clone first,
        # because the slice above is a view into attn_metadata's tensor.
        num_computed_tokens_cpu = num_computed_tokens_cpu.clone()
        num_computed_tokens_cpu[0] += tokens_skipped

    if splits_last_request:
        # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
        # the tokens skipped because query_start_loc_cpu might have been modified
        # if splits_first_request is True.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        #  (not cudagraph compatible)
        # Both the query and seq_len of the last request shrink by the same
        # amount here, so num_computed_tokens needs no adjustment.
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )


def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Build one CommonAttentionMetadata per UBatchSlice, in the order given,
    each covering only the requests/tokens of its slice.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(ubatch, common_attn_metadata)
        for ubatch in ubatch_slices
    ]


302
303
304
# Type variable for the backend-specific attention metadata type that an
# AttentionMetadataBuilder subclass produces (the return type of `build`).
M = TypeVar("M")


305
class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend

    Here we do not consider the cascade attention, as currently
    it is never cudagraph supported.

    Values are ordered from the broadest support (ALWAYS = 3) down to
    none (NEVER = 0)."""

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that
    are all the same; this can be used for spec-decode,
    i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1 decodes"""
    NEVER = 0
    """NO cudagraph support"""


322
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Abstract base for builders that turn a CommonAttentionMetadata into the
    backend-specific attention metadata type `M`.
    """

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    # Do not access directly. Call get_cudagraph_support() instead.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int | None = None
    # Does this backend/builder support updating the block table in existing
    # metadata
    supports_update_block_table: bool = False

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Store the constructor arguments on the instance."""
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @classmethod
    def get_cudagraph_support(
        cls: type["AttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """Get the cudagraph support level of this builder class.

        The base implementation ignores its arguments and returns the
        `_cudagraph_support` class attribute.
        """
        return cls._cudagraph_support

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
        """
        Set `self.reorder_batch_threshold`.

        Args:
            reorder_batch_threshold: Base threshold; None disables reordering.
            supports_spec_as_decode: If True (and the threshold is not None),
                raise the threshold to at least 1 + num_speculative_tokens
                when a speculative config with num_speculative_tokens is set.
            supports_dcp_with_varlen: If False and decode context parallelism
                is enabled (decode_context_parallel_size > 1), the threshold
                is forced to 1 — note this also happens when
                `reorder_batch_threshold` is None.
        """
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
            # If the backend supports spec-as-decode kernels, then we can set
            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
            speculative_config = self.vllm_config.speculative_config
            if (
                speculative_config is not None
                and speculative_config.num_speculative_tokens is not None
            ):
                self.reorder_batch_threshold = max(
                    self.reorder_batch_threshold,
                    1 + speculative_config.num_speculative_tokens,
                )

        if (
            self.vllm_config.parallel_config.decode_context_parallel_size > 1
            and not supports_dcp_with_varlen
        ):
            self.reorder_batch_threshold = 1

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed of execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def update_block_table(
        self,
        metadata: M,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> M:
        """
        Update the block table for the attention metadata.
        Faster when there are multiple kv-cache groups that create virtually
        the same metadata but just with different block tables.

        Only needs to be implemented if supports_update_block_table is True.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default
        (with common_prefix_len=0).
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default
        (with common_prefix_len=0 and fast_build=True).

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        """Whether to use cascade attention for this batch.

        The base implementation ignores its arguments and always returns
        False (cascade attention disabled).
        """
        return False

466

467
468
@functools.lru_cache
def get_kv_cache_layout():
    """
    Return the KV cache layout to use.

    Resolution order:
      1. `_KV_CACHE_LAYOUT_OVERRIDE` (installed by `set_kv_cache_layout`);
      2. the `VLLM_KV_CACHE_LAYOUT` environment variable, which must be one
         of the names in `KVCacheLayoutType`;
      3. the default from `get_kv_connector_cache_layout()`.

    The result is memoized by `functools.lru_cache`: later changes to the
    override or environment are not observed until
    `get_kv_cache_layout.cache_clear()` is called.
    """
    # Format specified by the code. The override is only read here, so no
    # `global` declaration is needed.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.",
            cache_layout,
        )
        return cache_layout

    # Format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    # When neither the user nor the override specified a layout, get default
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        assert is_valid_kv_cache_layout(cache_layout), (
            f"Invalid VLLM_KV_CACHE_LAYOUT={cache_layout!r}; "
            f"expected one of {get_args(KVCacheLayoutType)}"
        )
        logger.info_once(
            "`VLLM_KV_CACHE_LAYOUT` environment variable "
            "detected. Setting KV cache layout to %s.",
            cache_layout,
        )
    return cache_layout
494
495


496
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """
    Force `get_kv_cache_layout()` to return `cache_layout`.

    The override takes priority over the `VLLM_KV_CACHE_LAYOUT` environment
    variable. `get_kv_cache_layout` is memoized with `functools.lru_cache`,
    so its cache is cleared here; otherwise a layout computed before this
    call would keep being returned and the override would be silently
    ignored.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    get_kv_cache_layout.cache_clear()


501
502
503
504
@dataclass
class PerLayerParameters:
    """
    Currently, FlashInfer backend only support models in which all layers share
505
506
507
    the same values for the following hyperparameters. Should not be used for
    trtllm-gen backend since it supports different values for the following
    hyperparameters.
508
509
510
    """

    window_left: int
511
    logits_soft_cap: float | None
512
    sm_scale: float
513
    has_sinks: bool = False
514
    # has same params for all layers
515
516
    has_same_window_lefts: bool | None = field(default=None, compare=False)
    has_same_all_params: bool | None = field(default=None, compare=False)
517
518
519


def get_per_layer_parameters(
    vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
    """
    Build a PerLayerParameters entry for every attention layer in
    `layer_names`, for use during `plan`.

    Values are read from each layer's `impl`, which must be an instance of
    `cls_`:
      - window_left: `impl.sliding_window[0]`, or -1 when the impl has no
        `sliding_window` attribute or it is None
      - logits_soft_cap: `impl.logits_soft_cap`, or None when absent
      - sm_scale: `impl.scale`
      - has_sinks: whether `impl.sinks` exists and is not None
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config, AttentionLayerBase, layer_names
    )

    result: dict[str, PerLayerParameters] = {}
    for name, attn_layer in attn_layers.items():
        attn_impl = attn_layer.impl
        assert isinstance(attn_impl, cls_)

        sliding_window = getattr(attn_impl, "sliding_window", None)
        left = -1 if sliding_window is None else sliding_window[0]
        soft_cap = getattr(attn_impl, "logits_soft_cap", None)
        scale = attn_impl.scale
        sinks_present = getattr(attn_impl, "sinks", None) is not None

        result[name] = PerLayerParameters(
            window_left=left,
            logits_soft_cap=soft_cap,
            sm_scale=scale,
            has_sinks=sinks_present,
        )

    return result


def infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
    """
    Summarize per-layer hyperparameters into a single global set.

    Currently, FlashInfer backends other than trtllm-gen only support models
    in which all layers share the same `window_left`, `logits_soft_cap`,
    `sm_scale` (and `has_sinks`). This function does not enforce that.
    Instead it returns a copy of the first layer's parameters with two flags
    filled in:

      - `has_same_window_lefts`: True iff every layer has the same
        `window_left` as the first one.
      - `has_same_all_params`: True iff every layer compares equal to the
        first one (the `has_same_*` flags are excluded from comparison).

    Args:
        per_layer_params: Mapping from layer name to that layer's parameters.
            It is not modified.

    Returns:
        A new PerLayerParameters instance holding the first layer's values
        plus the two summary flags.

    Raises:
        AssertionError: If `per_layer_params` is empty.
    """
    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    # Work on a copy so the caller's entry for the first layer is not mutated.
    global_params = replace(param_sets[0])

    global_params.has_same_window_lefts = all(
        params.window_left == global_params.window_left for params in param_sets
    )
    global_params.has_same_all_params = all(
        params == global_params for params in param_sets
    )

    return global_params


578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
#  attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
#  sequences into local ("virtual") batches, where each local batch item is a
#  local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
    """
    Split each request into local-attention ("virtual") batch items, one per
    `attn_chunk_size`-aligned attention window touched by its query tokens.

    Args:
        attn_chunk_size: Size (in tokens) of each local attention window.
        common_attn_metadata: Metadata of the real batch.
        block_size: KV-cache block size in tokens. Must be positive and must
            divide `attn_chunk_size`; the default of 0 is rejected.

    Returns:
        A tuple of:
          - CommonAttentionMetadata describing the virtual batch (one
            "request" per local window).
          - A function mapping a block table laid out like
            `common_attn_metadata.block_table_tensor` to the virtual batch's
            block table, so it can be reused for update_block_table.

    Raises:
        AssertionError: If `block_size` is not positive or does not divide
            `attn_chunk_size`.
    """
    # Validate before any arithmetic: block_size is used as a divisor below.
    assert block_size > 0, f"block_size must be positive, got {block_size}"
    assert attn_chunk_size % block_size == 0, (
        f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
    )

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = block_starts[:, None] + np.arange(
        pages_per_local_batch, dtype=np.int32
    )
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)

    # Save as a lambda so we can return this for update_block_table
    make_block_table = lambda block_table: block_table[
        batch_indices_torch, block_indices_torch
    ].view(virtual_batches, -1)
    block_table_local = make_block_table(block_table)

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # Convert the numpy scalar to a plain int, matching the field's type.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
    ), make_block_table
771
772


773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Build metadata whose queries cover only the tokens listed in the logits
    indices (the KV-sharing "fast prefill" path).

    Every request keeps its original seq_lens, block table and slot mapping.
    Its query, however, is shrunk to the entries of
    `logits_indices_padded[:num_logits_indices]` that fall inside it. When
    `max_query_len == 1` the input object is returned as-is.

    Example:
        num_reqs:          3
        logits indices:    [14, 18, 19, 27]
        query_start_loc:   [0, 15, 20, 28]
        seq_lens:          [41, 31, 40]
        owning request:    [0, 1, 1, 2]
        tokens per req:    [1, 2, 1]
        new start locs:    [0, 1, 3, 4]
    """
    meta = common_attn_metadata
    # All requests are decode (assume 1 token for now): nothing to trim.
    if meta.max_query_len == 1:
        return meta

    assert meta.logits_indices_padded is not None
    assert meta.num_logits_indices is not None

    # Get rid of CUDAGraph padding, if any.
    logits_indices = meta.logits_indices_padded[: meta.num_logits_indices]
    num_reqs = meta.num_reqs
    query_start_loc = meta.query_start_loc

    # Request whose [start, end) query range contains each logits index.
    owner_req = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
    # Retained tokens per request; requests with none get a 0 entry.
    tokens_per_req = torch.bincount(owner_req, minlength=num_reqs)

    new_query_start_loc = torch.empty(
        num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype
    )
    new_query_start_loc[0] = 0
    new_query_start_loc[1:] = torch.cumsum(tokens_per_req, dim=0)

    new_max_query_len = int(tokens_per_req.max().item())
    new_num_tokens = int(tokens_per_req.sum().item())

    return CommonAttentionMetadata(
        query_start_loc=new_query_start_loc,
        query_start_loc_cpu=new_query_start_loc.to("cpu", non_blocking=True),
        seq_lens=meta.seq_lens,
        num_reqs=num_reqs,
        num_actual_tokens=new_num_tokens,
        max_query_len=new_max_query_len,
        max_seq_len=meta.max_seq_len,
        block_table_tensor=meta.block_table_tensor,
        slot_mapping=meta.slot_mapping,
        causal=True,
        _seq_lens_cpu=meta._seq_lens_cpu,
        _num_computed_tokens_cpu=meta._num_computed_tokens_cpu,
    )


832
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    The subclass is named `name_prefix + attention_backend_cls.__name__` and
    inherits everything else from `attention_backend_cls`.

    `get_builder_cls` is installed as a `staticmethod`. It can therefore be
    called on the class (`Backend.get_builder_cls()`) as well as on an
    instance. A bare zero-argument function stored in the class namespace
    would be bound to the instance in the latter case and raise TypeError.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )
845
846


Patrick von Platen's avatar
Patrick von Platen committed
847
848
849
850
851
852
853
854
855
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Create a subclass of `attention_backend_cls` whose class namespace is
    exactly `overrides`.

    The new class is named `name_prefix` followed by the base class name.
    Values in `overrides` are placed into the class body unchanged. Plain
    functions therefore become ordinary methods unless the caller wraps
    them in `staticmethod` or `classmethod`.
    """
    subclass_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    bases = (attention_backend_cls,)
    return type(subclass_name, bases, overrides)


856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
def split_decodes_prefills_and_extends(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
    """
    Assuming a batch reordered as decode -> extend -> prefill, find the
    boundaries between the three groups of requests.

    A request is a decode if its query length is <= `decode_threshold`. A
    non-decode request is a prefill if its query covers its whole sequence
    (`seq_len == query_len`); otherwise it is an extend.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        A tuple of plain Python ints:
        num_decodes: The number of decode requests.
        num_extends: The number of extend requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_extend_tokens: The number of tokens in the extend requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
    seq_lens = common_attn_metadata.seq_lens_cpu

    # Fast path: every request is a decode.
    if max_query_len <= decode_threshold:
        return num_reqs, 0, 0, num_tokens, 0, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill_or_extend = query_lens > decode_threshold
    if not torch.any(is_prefill_or_extend):
        # No query actually exceeds the threshold (e.g. max_query_len
        # over-estimated the lengths), so every request is a decode. This
        # must be handled before the argmax below: argmax of an all-False
        # mask is 0, which would otherwise report zero decodes.
        return num_reqs, 0, 0, num_tokens, 0, 0

    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend

    # Decodes come first, so the first non-decode index is the decode count.
    num_decodes = int(is_prefill_or_extend.int().argmax(dim=-1).item())
    num_decode_tokens = int(query_start_loc[num_decodes].item())
    num_prefills_or_extends = num_reqs - num_decodes
    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens

    if not torch.any(is_prefill):
        return (
            num_decodes,
            num_prefills_or_extends,
            0,
            num_decode_tokens,
            num_prefill_or_extend_tokens,
            0,
        )

    # Extends precede prefills, so the first prefill index ends the extends.
    first_prefill = int(is_prefill.int().argmax(dim=-1).item())
    num_extends = first_prefill - num_decodes
    num_prefills = num_reqs - first_prefill

    # Convert with .item() so these counts are ints like the others, not
    # 0-d tensors.
    num_prefill_tokens = num_tokens - int(query_start_loc[first_prefill].item())
    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
    return (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    )


923
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    """
    Split a batch that has already been reordered (decodes first) into its
    decode and prefill parts.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.
        require_uniform: If True, all decode requests must share the first
            request's query length. A request with a different length is then
            treated as a prefill even if it is <= decode_threshold.

    Returns:
        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
    """
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    all_decodes = (num_reqs, 0, num_tokens, 0)

    # Fast path: no query exceeds the threshold. Under require_uniform with a
    # threshold > 1 the lengths could still differ, so that case falls through.
    uniformity_trivial = not require_uniform or decode_threshold <= 1
    if common_attn_metadata.max_query_len <= decode_threshold and uniformity_trivial:
        return all_decodes

    start_locs = common_attn_metadata.query_start_loc_cpu
    query_lens = start_locs[1:] - start_locs[:-1]
    lead_len = query_lens[0]

    # Decodes sit at the front, so a non-decode first request means no decodes.
    if lead_len.item() > decode_threshold:
        return 0, num_reqs, 0, num_tokens

    if require_uniform:
        # Padded uniform batch (used for full CUDA graphs). Zero-length
        # padding requests are safe to count as decodes, and doing so keeps
        # num_decodes equal to the captured size.
        if torch.all((query_lens == lead_len) | (query_lens == 0)):
            assert num_reqs * lead_len == num_tokens, "tokens not padded correctly"
            return all_decodes
        is_prefill = query_lens != lead_len
    else:
        is_prefill = query_lens > decode_threshold

    if not torch.any(is_prefill):
        return all_decodes

    split = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[:split] <= decode_threshold)
    decode_tokens = start_locs[split].item()
    return (split, num_reqs - split, decode_tokens, num_tokens - decode_tokens)


984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
    """
    Greedily group consecutive prefill requests into chunks whose summed
    sequence length never exceeds `workspace_size`.

    Args:
        seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
        workspace_size: The maximum workspace size (in tokens) per chunk.
        request_offset: Added to every returned request index.
    Returns:
        A list of half-open `(reqs_start, reqs_end)` request-index ranges, in
        order.
    """
    # Every single request must fit on its own; otherwise no split exists.
    assert torch.all(seq_lens_cpu <= workspace_size).item()

    lengths = seq_lens_cpu.tolist()
    bounds: list[tuple[int, int]] = []
    chunk_start = 0
    chunk_tokens = 0
    for idx, length in enumerate(lengths):
        if chunk_tokens + length > workspace_size:
            # Current request overflows: close the chunk and open a new one.
            bounds.append((chunk_start + request_offset, idx + request_offset))
            chunk_start, chunk_tokens = idx, 0
        chunk_tokens += length
    if chunk_start < len(lengths):
        bounds.append((chunk_start + request_offset, len(lengths) + request_offset))
    return bounds


1011
1012
1013
1014
1015
1016
1017
1018
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorder `input_batch` in place into decode -> extend -> prefill order.

        decode:  num_scheduled_tokens <= decode_threshold
        extend:  not a decode, and num_computed_tokens > 0
        prefill: not a decode, and num_computed_tokens == 0

    "Decode" and "prefill" are used loosely here. They mean requests whose
    attention is likely memory-bound and likely compute-bound, respectively.

    Only requests that are not already in their target region are moved.
    Moves go through `input_batch.swap_states`, following the cycles of the
    required permutation.

    Returns:
        True if the batch was modified, False otherwise.
    """
    req_ids = input_batch.req_ids
    num_reqs = len(req_ids)
    sched_tokens = np.array(
        [scheduler_output.num_scheduled_tokens[req_id] for req_id in req_ids]
    )
    computed_tokens = input_batch.num_computed_tokens_cpu[:num_reqs]

    decode_mask = sched_tokens <= decode_threshold
    extend_mask = ~decode_mask & (computed_tokens > 0)
    prefill_mask = ~decode_mask & (computed_tokens == 0)

    # Region each request is in now: 0 = decode, 1 = extend, 2 = prefill.
    current = np.where(prefill_mask, 2, np.where(extend_mask, 1, 0)).astype(np.int32)

    # Region each slot should hold once sorted.
    n_dec = int(decode_mask.sum())
    n_ext = int(extend_mask.sum())
    target = np.repeat(
        np.arange(3, dtype=np.int32), [n_dec, n_ext, num_reqs - n_dec - n_ext]
    )

    mismatched = np.flatnonzero(current != target)
    if mismatched.size == 0:
        return False

    # Misplaced requests, stably ordered by region, are assigned to the
    # misplaced slots in ascending order.
    sources = mismatched[np.argsort(current[mismatched], kind="stable")]
    destination = {int(s): int(d) for s, d in zip(sources, mismatched)}

    # Walk each permutation cycle. After swap(start, target_pos) the element
    # now at `start` came from target_pos and must continue to
    # destination[target_pos]. Visited slots are marked by mapping them to
    # themselves.
    for start in destination:
        target_pos = destination[start]
        while target_pos != start:
            input_batch.swap_states(start, target_pos)
            following = destination.get(target_pos, target_pos)
            destination[target_pos] = target_pos
            target_pos = following

    return True
1076
1077


1078
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    View a flat (total_tokens, num_heads, head_dim) query as
    (batch_size, seq_len, num_heads, head_dim), where
    seq_len = total_tokens // batch_size.

    Raises AssertionError if `query` is not 3-D or if total_tokens is not a
    multiple of batch_size. The result is a `view` of `query`.
    """
    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
    total_tokens, num_heads, head_dim = query.shape
    assert total_tokens % batch_size == 0, (
        f"{total_tokens=} is not divisible by {batch_size=}"
    )
    return query.view(batch_size, total_tokens // batch_size, num_heads, head_dim)


1094
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
    """
    Collapse a (batch_size, seq_len, num_heads, head_dim) attention output
    into (batch_size * seq_len, num_heads, head_dim).

    A 3-D input is returned unchanged (same object). Any rank other than 3
    or 4 raises AssertionError. The 4-D case uses `view`, so no copy is made.
    """
    ndim = attn_output.dim()
    if ndim != 3:
        assert ndim == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
        batch_size, seq_len, num_heads, head_dim = attn_output.shape
        attn_output = attn_output.view(batch_size * seq_len, num_heads, head_dim)
    return attn_output
1105
1106


1107
1108
1109
1110
1111
1112
1113
1114
1115
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Dynamically derive a dataclass from `metadata_cls` that adds `fields`.

    Args:
        name_prefix: Prepended to `metadata_cls.__name__` to form the new
            class name.
        metadata_cls: Base class of the generated dataclass.
        fields: `(name, type, field_spec)` tuples, forwarded unchanged to
            `dataclasses.make_dataclass`.
    Returns:
        The newly created dataclass type.
    """
    # NOTE: inside this function the `fields` parameter shadows the
    # module-level `dataclasses.fields` import.
    new_name: str = name_prefix + metadata_cls.__name__  # type: ignore
    return make_dataclass(new_name, fields, bases=(metadata_cls,))


1120
1121
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """
    Structural type for attention metadata that also carries the logits
    indices used by the KV-sharing fast-prefill path.
    """

    # Logits indices, possibly padded (e.g. for CUDA graphs). Only the first
    # `num_logits_indices` entries are meaningful.
    logits_indices_padded: torch.Tensor | None = None
    # Number of valid leading entries in `logits_indices_padded`.
    num_logits_indices: int | None = None
1124
1125
1126
1127
1128
1129
1130
1131
1132


def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """
    Wrap an attention backend so that its metadata builder uses the
    KV-sharing fast-prefill path.

    The returned backend subclasses `underlying_attn_backend` (named
    `prefix + <backend name>`). Its `get_builder_cls` returns a builder that:
      1. trims the common metadata via
         `make_kv_sharing_fast_prefill_common_attn_metadata`,
      2. delegates to the underlying builder's `build`, and
      3. returns a shallow copy of the resulting metadata that also carries
         `logits_indices_padded` and `num_logits_indices` taken from the
         original (untrimmed) common metadata.

    NOTE(review): `underlying_attn_backend` is annotated as an instance, but
    it is used like a backend class. Its `get_builder_cls()` is called, and
    it is passed as `attention_backend_cls` to `subclass_attention_backend`.
    """
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Restrict the queries to the logits-producing tokens before
            # handing off to the underlying builder.
            new_common_attn_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            )
            metadata = super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

            # NOTE(review): a fresh metadata subclass is defined on every
            # build() call. Consider caching it per `metadata.__class__`.
            class KVSharingFastPrefillAttentionMetadata(
                metadata.__class__,  #  type: ignore
                KVSharingFastPrefillMetadata,
            ):
                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all fields in metadata cls. Attributes are
                    # copied directly instead of calling the base __init__.
                    # `fields` is dataclasses.fields, so the underlying
                    # metadata class must be a dataclass.
                    for _field in fields(metadata.__class__):
                        setattr(self, _field.name, getattr(metadata, _field.name))

                    # Taken from the original (untrimmed) common metadata.
                    self.logits_indices_padded = (
                        common_attn_metadata.logits_indices_padded
                    )
                    self.num_logits_indices = common_attn_metadata.num_logits_indices

            return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder,
    )

    return attn_backend
1169
1170
1171
1172


def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """
    Precompute per-chunk launch metadata for the causal_conv1d kernel.

    Sequence lengths are the consecutive differences of `query_start_loc_p`.
    Each sequence is split into ceil(len / BLOCK_M) chunks. For every chunk,
    two values are recorded:
      - `batch_ptr`: the index of the sequence it belongs to.
      - `token_chunk_offset_ptr`: its chunk index within that sequence.
    Both are int32 tensors on `query_start_loc_p.device`, filled with
    PAD_SLOT_ID beyond the valid prefix. Their length is
    MAX_NUM_PROGRAMS = 2 * max(1024, total_chunks).

    Returns:
        nums_dict: per-BLOCK_M dict with keys "nums" (chunks per sequence),
            "tot" (total chunks), "mlist", "mlist_len", "offsetlist",
            "batch_ptr" and "token_chunk_offset_ptr".
        batch_ptr: see above.
        token_chunk_offset_ptr: see above.
    """
    # Needed for causal_conv1d
    seqlens = query_start_loc_p.diff().to("cpu")
    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None
    device = query_start_loc_p.device
    # NOTE(review): if more BLOCK_M values are added, all entries of
    # nums_dict will alias the same batch_ptr / token_chunk_offset_ptr
    # tensors, which are overwritten on each iteration.
    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # Number of BLOCK_M-sized chunks per sequence (ceil division).
        nums = -(-seqlens // BLOCK_M)
        nums_dict[BLOCK_M] = {}
        nums_dict[BLOCK_M]["nums"] = nums
        nums_dict[BLOCK_M]["tot"] = nums.sum().item()
        # Owning sequence of every chunk, e.g. nums=[2, 1, 3] -> [0, 0, 1, 2, 2, 2].
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        nums_dict[BLOCK_M]["mlist"] = mlist
        mlist_len = len(nums_dict[BLOCK_M]["mlist"])
        nums_dict[BLOCK_M]["mlist_len"] = mlist_len
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2

        # Chunk index within its own sequence, e.g. nums=[2, 1, 3] ->
        # [0, 1, 0, 0, 1, 2]. Computed as the global chunk position minus the
        # global position of the owning sequence's first chunk. This is a
        # vectorized replacement for a per-sequence Python loop.
        nums_np = nums.numpy().astype(np.int64)
        first_chunk = np.cumsum(nums_np) - nums_np
        offsets_np = np.arange(mlist_len, dtype=np.int64) - np.repeat(
            first_chunk, nums_np
        )
        offsetlist = torch.from_numpy(offsets_np.astype(np.int32))
        nums_dict[BLOCK_M]["offsetlist"] = offsetlist

        if batch_ptr is None:
            # First iteration: allocate the PAD_SLOT_ID-filled launch buffers.
            batch_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
            token_chunk_offset_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
        else:
            # Later iterations: grow (and re-pad) the buffers if too small.
            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
                token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS
                ).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[0:mlist_len].copy_(offsetlist)  # type: ignore
        nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
        nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

    return nums_dict, batch_ptr, token_chunk_offset_ptr
1217
1218
1219
1220


def get_dcp_local_seq_lens(
    seq_lens: torch.Tensor,
1221
    dcp_size: int = 1,
1222
    dcp_rank: int | None = None,
1223
    cp_kv_cache_interleave_size: int = 1,
1224
1225
1226
1227
1228
1229
1230
1231
) -> torch.Tensor:
    """While using dcp, kv_cache size stored on each rank may be different,
    use this function to calculate split decode seq_lens of each dcp rank.
    Only consider dcp now, we can extend the case of cp based on this.
    """
    num_requests = seq_lens.size(0)
    if dcp_rank is None:
        rank_offsets = (
1232
            torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
1233
1234
1235
1236
            .unsqueeze(0)
            .repeat(num_requests, 1)
        )
    else:
1237
1238
1239
        rank_offsets = torch.tensor(
            [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
        )
1240
1241
1242
1243
1244
    seq_lens_tiled = (
        seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
    )
    base = (
        seq_lens_tiled
1245
1246
1247
        // cp_kv_cache_interleave_size
        // dcp_size
        * cp_kv_cache_interleave_size
1248
    )
1249
    remainder = seq_lens_tiled - base * dcp_size
1250
    remainder = torch.clip(
1251
        remainder - rank_offsets * cp_kv_cache_interleave_size,
1252
        0,
1253
        cp_kv_cache_interleave_size,
1254
1255
1256
    )
    dcp_local_seq_lens = base + remainder
    return dcp_local_seq_lens.squeeze(1)