utils.py 36.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
from collections.abc import Callable
5
from dataclasses import dataclass, field, fields, make_dataclass
6
7
8
9
10
11
12
13
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Protocol,
    TypeVar,
    get_args,
)
14

15
import numpy as np
16
import torch
17
from typing_extensions import runtime_checkable
18

19
from vllm.config import VllmConfig, get_layers_from_vllm_config
20
from vllm.utils.math_utils import cdiv
21

22
23
24
25
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

26
27
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
28
29
    get_kv_connector_cache_layout,
)
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
32
33
34
35
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
36
37
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
38
)
39
from vllm.v1.worker.ubatch_utils import UBatchSlice
40
41

logger = init_logger(__name__)

# The KV cache layouts this module understands; `is_valid_kv_cache_layout`
# checks strings against exactly these literals.
KVCacheLayoutType = Literal["NHD", "HND"]
# Layout forced from code through `set_kv_cache_layout`; None = no override.
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None

# Sentinel slot id; presumably marks padded token slots -- TODO confirm with
# the consumers of this constant.
PAD_SLOT_ID = -1

47
48
49

def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True if `value` is one of the literals of `KVCacheLayoutType`."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
50

51

52
53
54
55
56
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Return the query start offsets for the requests selected by
    `request_slice`, rebased so that the first selected request starts at 0.

    The result is a freshly allocated tensor (the subtraction allocates), so
    this is not cudagraph compatible.
    """
    first = request_slice.start
    # One extra entry is needed to close the last selected request.
    window = query_start_loc[first : request_slice.stop + 1]
    return window - query_start_loc[first]
67
68
69


def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    Create a new CommonAttentionMetadata restricted to the requests and tokens
    covered by `ubatch_slice`.

    A request may be split across ubatches:
    - If the first request of the slice continues a request that started in an
      earlier slice, its query is shortened by the tokens handled earlier and
      its `num_computed_tokens` is advanced by the same amount. This keeps
      `num_computed_tokens == seq_len - query_len` for that request (assuming
      the input metadata satisfies that relation).
    - If the last request continues into the next slice, both its query and
      its seq_len are shortened by the deferred tokens.

    `attn_metadata` itself is not modified: every adjusted tensor is either a
    fresh tensor or a clone.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    # NOTE: last token can be outside of the last request if we have CG padding.

    # If the request is split across ubatches, we have to adjust the metadata.
    # splits_first_request: The first request in this slice is the continuation of
    #                       a request that started in a previous slice.
    # splits_last_request:  The last request in this slice continues into the
    #                       next slice.
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    # slice_query_start_locs returns new tensors, so in-place edits below do
    # not touch attn_metadata.
    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )

    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    if splits_first_request:
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped
        # The skipped tokens belong to an earlier ubatch, so for this slice
        # they are already-computed context: seq_len is unchanged while the
        # query shrinks. Clone before editing so the caller's tensor (a view
        # of attn_metadata) is left untouched.
        num_computed_tokens_cpu = num_computed_tokens_cpu.clone()
        num_computed_tokens_cpu[0] += tokens_skipped

    if splits_last_request:
        # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
        # the tokens skipped because query_start_loc_cpu might have been modified
        # if splits_first_request is True.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        #  (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped
        # num_computed_tokens needs no change here: seq_len and query_len
        # both shrink by tokens_skipped.

    max_seq_len = int(seq_lens_cpu.max())

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )


def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Build one CommonAttentionMetadata per entry of `ubatch_slices`, in the
    same order, each restricted to the requests/tokens of its slice.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
        for ubatch_slice in ubatch_slices
    ]


181
182
@functools.lru_cache
def get_kv_cache_layout():
    """
    Resolve the KV cache layout (memoized by `functools.lru_cache`).

    Priority order:
    1. a layout forced from code via `_KV_CACHE_LAYOUT_OVERRIDE`;
    2. `envs.VLLM_KV_CACHE_LAYOUT`, asserted to be a valid layout;
    3. otherwise, the default from `get_kv_connector_cache_layout()`.
    """
    # Format specified by the code: takes precedence over everything else.
    override = _KV_CACHE_LAYOUT_OVERRIDE
    if override is not None:
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.",
            override,
        )
        return override

    # Format specified by the user.
    env_layout = envs.VLLM_KV_CACHE_LAYOUT
    if env_layout is None:
        # Neither the code nor the user chose a layout: use the default.
        return get_kv_connector_cache_layout()

    assert is_valid_kv_cache_layout(env_layout)
    logger.info_once(
        "`VLLM_KV_CACHE_LAYOUT` environment variable "
        "detected. Setting KV cache layout to %s.",
        env_layout,
    )
    return env_layout
209
210


211
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """
    Force the KV cache layout reported by `get_kv_cache_layout`.

    `get_kv_cache_layout` is memoized with `functools.lru_cache`. Without
    clearing that cache, an override set after the first lookup would be
    silently ignored. Clearing it makes the next call see the override.
    Values that callers already obtained and stored are not affected.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    get_kv_cache_layout.cache_clear()


216
217
218
219
@dataclass
class PerLayerParameters:
    """
    Currently, FlashInfer backend only support models in which all layers share
220
221
222
    the same values for the following hyperparameters. Should not be used for
    trtllm-gen backend since it supports different values for the following
    hyperparameters.
223
224
225
    """

    window_left: int
226
    logits_soft_cap: float | None
227
    sm_scale: float
228
    has_sinks: bool = False
229
    # has same params for all layers
230
231
    has_same_window_lefts: bool | None = field(default=None, compare=False)
    has_same_all_params: bool | None = field(default=None, compare=False)
232
233
234


def get_per_layer_parameters(
    vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
    """
    Scan the layers named in `layer_names` and collect, per layer name, the
    hyperparameters to use during `plan`.

    Every layer's `impl` must be an instance of `cls_` (asserted). Missing
    optional attributes fall back to: no sliding window (`window_left = -1`),
    no logits soft cap (`None`) and no sinks (`has_sinks = False`).
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
        layer_names,
    )

    result: dict[str, PerLayerParameters] = {}
    for layer_name, attn_layer in attn_layers.items():
        layer_impl = attn_layer.impl
        assert isinstance(layer_impl, cls_)

        # Only the left edge of the sliding window is needed.
        sliding_window = getattr(layer_impl, "sliding_window", None)
        result[layer_name] = PerLayerParameters(
            -1 if sliding_window is None else sliding_window[0],
            getattr(layer_impl, "logits_soft_cap", None),
            layer_impl.scale,
            getattr(layer_impl, "sinks", None) is not None,
        )

    return result


def infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
    """
    Reduce per-layer hyperparameters to a single global set.

    Currently, FlashInfer backends other than trtllm-gen only support models
    in which all layers share the same values for:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    The entry of the first layer is returned as the global parameters. It is
    updated in place with two flags:
    - `has_same_window_lefts`: every layer has the same `window_left`;
    - `has_same_all_params`: every layer's params compare equal to it.

    Differences between layers do not raise here; they are only recorded in
    these flags.
    """
    assert len(per_layer_params) > 0, "No attention layers found in the model."

    layer_params = list(per_layer_params.values())
    reference = layer_params[0]

    reference.has_same_window_lefts = all(
        candidate.window_left == reference.window_left for candidate in layer_params
    )
    # has_same_* are excluded from dataclass equality, so setting the flag
    # above does not influence this comparison.
    reference.has_same_all_params = all(
        candidate == reference for candidate in layer_params
    )
    return reference


297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing chunked prefill on a batch of 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
#  attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
#  sequences into local ("virtual") batches, where each local batch item is a
#  local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
    """
    Break every request into local-attention "virtual" requests, one per
    `attn_chunk_size`-aligned window it touches.

    Args:
        attn_chunk_size: Size, in tokens, of each local attention window.
        common_attn_metadata: Metadata for the original (non-virtual) batch.
        block_size: KV cache block size in tokens. Must be positive and must
            divide `attn_chunk_size`.

    Returns:
        `(metadata, make_block_table)`. `metadata` describes the virtual
        batch. `make_block_table` re-gathers a virtual block table from a
        block table of the original batch using the same indices (so it can
        be re-applied when the original block table is updated).

    Raises:
        AssertionError: if `block_size` is not positive or does not divide
            `attn_chunk_size`.
    """
    # Validate before any division by block_size (the default of 0 would
    # otherwise reach the floor division below).
    assert block_size > 0, f"block_size must be positive, got {block_size}"
    assert attn_chunk_size % block_size == 0, (
        f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}"
    )

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = block_starts[:, None] + np.arange(
        pages_per_local_batch, dtype=np.int32
    )
    # Clip so trailing pages of a final partial window stay inside the table.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)

    def make_block_table(block_table: torch.Tensor) -> torch.Tensor:
        """Gather the virtual-batch block table (returned for update_block_table)."""
        return block_table[batch_indices_torch, block_indices_torch].view(
            virtual_batches, -1
        )

    block_table_local = make_block_table(block_table)

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # Plain int (not a numpy scalar), like every other metadata builder.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
    ), make_block_table
490
491


492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Restrict the query side of the metadata to the tokens selected by
    `logits_indices_padded[:num_logits_indices]`.

    Each request keeps only the selected tokens that fall inside it; seq_lens,
    block table, slot mapping and the cached CPU tensors are reused from the
    input. If `max_query_len == 1` the input metadata is returned unchanged.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    logits_indices_padded = common_attn_metadata.logits_indices_padded
    num_logits_indices = common_attn_metadata.num_logits_indices
    # Get rid of CUDAGraph padding, if any
    logits_indices = logits_indices_padded[:num_logits_indices]
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    # Example inputs
    # num_reqs: 3
    # generation_indices:  [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens:        [41, 31, 40]

    # Find how many decode indices belong to each request
    # request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)

    # Figure out how many tokens are in each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # Calculate new query_start_loc with tokens in generation_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(
        num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype
    )
    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)

    # Enqueue the (possibly asynchronous) device-to-host copy BEFORE the
    # `.item()` calls below. `.item()` waits for the work already queued on
    # the same stream, so the CPU copy is complete by the time the metadata is
    # built. Issued after the last `.item()`, the copy would be left without
    # any synchronization before the CPU tensor is read.
    decode_query_start_loc_cpu = decode_query_start_loc.to("cpu", non_blocking=True)
    decode_max_query_len = int(num_decode_tokens.max().item())
    total_num_decode_tokens = int(num_decode_tokens.sum().item())

    return CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        query_start_loc_cpu=decode_query_start_loc_cpu,
        seq_lens=common_attn_metadata.seq_lens,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_decode_tokens,
        max_query_len=decode_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
        _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
        _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
    )


551
552
553
# Type parameter for the metadata type produced by the builder class given to
# `subclass_attention_backend` (it appears as `AttentionMetadataBuilder[M]`).
M = TypeVar("M")


554
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    The subclass is named `name_prefix + attention_backend_cls.__name__`.
    `get_builder_cls` is installed as a staticmethod, so it works both when
    called on the class and when called on an instance. A bare zero-argument
    function would receive `self` on an instance call and raise TypeError.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )
567
568


Patrick von Platen's avatar
Patrick von Platen committed
569
570
571
572
573
574
575
576
577
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a new subclass of `attention_backend_cls`, named
    `name_prefix + attention_backend_cls.__name__`, whose class namespace is
    `overrides`; entries there shadow the parent's attributes.
    """
    base_name = attention_backend_cls.__name__  # type: ignore
    subclass_name: str = name_prefix + base_name
    bases = (attention_backend_cls,)
    return type(subclass_name, bases, overrides)


578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
def split_decodes_prefills_and_extends(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
    """
    Assuming a batch reordered as decode -> extend -> prefill, finds the
    boundaries between the three groups of requests.

    A request is a decode if its query length is <= `decode_threshold`.
    Otherwise it is a prefill when its whole sequence is in this step's query
    (`seq_len == query_len`), and an extend otherwise.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        A tuple of Python ints:
        num_decodes: The number of decode requests.
        num_extends: The number of extend requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_extend_tokens: The number of tokens in the extend requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
    seq_lens = common_attn_metadata.seq_lens_cpu

    if max_query_len <= decode_threshold:
        return num_reqs, 0, 0, num_tokens, 0, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill_or_extend = query_lens > decode_threshold
    if not torch.any(is_prefill_or_extend):
        # max_query_len over-estimated the real query lengths: every request
        # is a decode. (Checked before argmax, which would otherwise return 0
        # and drop every request.)
        return num_reqs, 0, 0, num_tokens, 0, 0

    # First non-decode request marks the decode / (extend|prefill) boundary.
    first_extend = int(is_prefill_or_extend.int().argmax(dim=-1).item())
    num_decodes = first_extend
    num_decode_tokens = int(query_start_loc[first_extend].item())

    num_prefills_or_extends = num_reqs - num_decodes
    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens

    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
    if not torch.any(is_prefill):
        return (
            num_decodes,
            num_prefills_or_extends,
            0,
            num_decode_tokens,
            num_prefill_or_extend_tokens,
            0,
        )

    first_prefill = int(is_prefill.int().argmax(dim=-1).item())
    num_extends = first_prefill - num_decodes
    num_prefills = num_reqs - first_prefill

    # `.item()` keeps the token counts plain ints, like the other return paths.
    num_prefill_tokens = num_tokens - int(query_start_loc[first_prefill].item())
    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
    return (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    )


645
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    """
    Assuming a batch with all decodes placed before all prefills, locate the
    decode/prefill boundary.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.
        require_uniform: If True, requires that all decode requests have the
            same query length. When set, some queries may be considered prefills
            even if they are <= decode_threshold, in order to ensure uniformity.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    n_reqs = common_attn_metadata.num_reqs
    n_tokens = common_attn_metadata.num_actual_tokens
    qsl = common_attn_metadata.query_start_loc_cpu

    everything_decode = (n_reqs, 0, n_tokens, 0)

    # Fast path: nothing exceeds the threshold, and uniformity either is not
    # required or is trivially satisfied for single-token decodes.
    uniformity_trivial = (not require_uniform) or decode_threshold <= 1
    if max_query_len <= decode_threshold and uniformity_trivial:
        return everything_decode

    query_lens = qsl[1:] - qsl[:-1]
    head_len = query_lens[0]
    if head_len.item() > decode_threshold:
        # first request is not decode, so no decode requests
        return 0, n_reqs, 0, n_tokens

    if not require_uniform:
        is_prefill = query_lens > decode_threshold
    elif torch.all((query_lens == head_len) | (query_lens == 0)):
        # Padded uniform batch (full-CG): zero-length requests are padding and
        # may be counted as decodes so num_decodes matches the captured size.
        assert n_reqs * head_len == n_tokens, "tokens not padded correctly"
        return everything_decode
    else:
        is_prefill = query_lens != head_len

    if not torch.any(is_prefill):
        return everything_decode

    num_decodes = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[:num_decodes] <= decode_threshold)
    num_decode_tokens = qsl[num_decodes].item()
    return (
        num_decodes,
        n_reqs - num_decodes,
        num_decode_tokens,
        n_tokens - num_decode_tokens,
    )


706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
    """
    Greedily split the prefill requests into consecutive chunks such that the
    total sequence length of each chunk is less than or equal to the
    workspace size.

    Args:
        seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
        workspace_size: The maximum workspace size (in tokens) per chunk.
        request_offset: The offset to add to the request indices.
    Returns:
        A list of tuples of (reqs_start, reqs_end) representing chunk boundaries
        (half-open, offset by `request_offset`). Empty input gives an empty list.
    """
    # Every single request must fit; this also guarantees that each outer
    # iteration consumes at least one request, so the loop terminates.
    assert torch.all(seq_lens_cpu <= workspace_size).item()

    # Convert once instead of calling `.item()` per element inside the loop.
    seq_lens = seq_lens_cpu.tolist()
    chunk_bounds: list[tuple[int, int]] = []
    i, n = 0, len(seq_lens)

    while i < n:
        start, chunk_total = i, 0
        while i < n and chunk_total + seq_lens[i] <= workspace_size:
            chunk_total += seq_lens[i]
            i += 1
        chunk_bounds.append((start + request_offset, i + request_offset))
    return chunk_bounds


733
734
735
736
737
738
739
740
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch in place into decode -> extend -> prefill order.

    Each request is classified as:
      - prefill: no computed tokens yet (``num_computed_tokens == 0``),
        regardless of how many tokens are scheduled;
      - decode: has computed tokens and <= decode_threshold scheduled tokens;
      - extend: has computed tokens and > decode_threshold scheduled tokens.

    Requests are moved with ``input_batch.swap_states``; only requests not
    already in their target region are touched.

    NOTE(review): a brand-new request with <= decode_threshold scheduled
    tokens is therefore placed with the prefills, not the decodes — confirm
    this is intended.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # NOTE: "decode" loosely means requests whose attention is likely
    # memory-bound and "prefill" those whose attention is likely compute-bound.
    num_reqs = len(input_batch.req_ids)
    num_scheduled_tokens_np = np.array(
        [
            scheduler_output.num_scheduled_tokens[req_id]
            for req_id in input_batch.req_ids
        ]
    )
    num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]

    is_prefill = num_computed_tokens_np == 0
    is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill)
    is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill)

    # Current region of each request: 0 = decode, 1 = extend, 2 = prefill.
    req_regions = np.zeros(is_decode.shape, dtype=np.int32)
    req_regions[is_extend] = 1
    req_regions[is_prefill] = 2

    num_decodes = int(is_decode.sum())
    num_extends = int(is_extend.sum())

    # Region each batch slot must hold after reordering.
    target_regions = np.zeros(num_reqs, dtype=np.int32)
    target_regions[num_decodes : num_decodes + num_extends] = 1
    target_regions[num_decodes + num_extends :] = 2

    needs_swap = req_regions != target_regions
    if not needs_swap.any():
        return False

    # Misplaced slots (ascending) and the misplaced requests stably sorted by
    # their region: the k-th request in src_indices moves to the k-th slot in
    # orig_indices.
    orig_indices = np.where(needs_swap)[0]
    sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
    src_indices = orig_indices[sorted_order]
    src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}

    # Apply the permutation cycle by cycle with in-place swaps.
    for src in src_dest_map:
        dst = src_dest_map[src]
        while src != dst:
            input_batch.swap_states(src, dst)
            # Mark dst as done by updating its destination to itself
            next_dst = src_dest_map.get(dst, dst)
            src_dest_map[dst] = dst
            dst = next_dst

    return True
798
799


800
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Reshapes a flattened (total_tokens, num_heads, head_dim) query into
    (batch_size, seq_len, num_heads, head_dim), where every request
    contributes the same seq_len = total_tokens // batch_size tokens.

    The result is a ``view`` of ``query`` (no copy).

    Raises:
        AssertionError: if ``query`` is not 3D, ``batch_size`` is not
            positive, or total_tokens is not divisible by ``batch_size``.
    """
    assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
    # Guard explicitly so batch_size == 0 does not surface as a modulo error.
    assert batch_size > 0, f"{batch_size=} must be positive"
    total_tokens, num_heads, head_dim = query.shape
    assert total_tokens % batch_size == 0, (
        f"{total_tokens=} is not divisible by {batch_size=}"
    )
    seq_len = total_tokens // batch_size
    return query.view(batch_size, seq_len, num_heads, head_dim)


816
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
    """
    Merges the batch_size and seq_len dimensions of a 4D attention output
    (batch_size, seq_len, num_heads, head_dim) into a single token dimension,
    giving (batch_size * seq_len, num_heads, head_dim) as a ``view``.

    A 3D input is treated as already flattened and returned unchanged.

    Raises:
        AssertionError: if the input is neither 3D nor 4D.
    """
    if attn_output.dim() == 3:
        # Already in the correct shape
        return attn_output
    assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D"
    batch_size, seq_len, num_heads, head_dim = attn_output.shape
    return attn_output.view(batch_size * seq_len, num_heads, head_dim)
827
828


829
830
831
832
833
834
835
836
837
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields.

    Args:
        name_prefix: Prepended to `metadata_cls.__name__` to form the name of
            the new class.
        metadata_cls: The base class to extend (presumably a dataclass —
            confirm against callers).
        fields: `(name, type, default_or_field)` tuples, passed unchanged to
            `dataclasses.make_dataclass`.
    Returns:
        The newly created dataclass subclass.
    """
    # NOTE: the `fields` parameter shadows `dataclasses.fields` inside this
    # function; the name is kept so keyword callers keep working.
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    return make_dataclass(name, fields, bases=(metadata_cls,))


842
843
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
844
845
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None
846
847
848
849


def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
    """
    Wrap an attention backend so that its metadata builder applies the
    KV-sharing fast-prefill transformation.

    The returned backend comes from ``subclass_attention_backend`` with
    ``name_prefix=prefix``. Its builder:
      1. rewrites the common attention metadata with
         ``make_kv_sharing_fast_prefill_common_attn_metadata``;
      2. delegates to the underlying builder's ``build``;
      3. wraps the resulting metadata in a subclass that also exposes
         ``logits_indices_padded`` and ``num_logits_indices``, copied from the
         caller's original ``common_attn_metadata``. The subclass therefore
         satisfies ``KVSharingFastPrefillMetadata``.

    Args:
        prefix: Name prefix for the generated backend class.
        underlying_attn_backend: The backend whose builder is wrapped.
    Returns:
        The new backend class.
    """
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_common_attn_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            )
            metadata = super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

            # NOTE: a fresh subclass of the concrete metadata class is
            # created on every build() call.
            class KVSharingFastPrefillAttentionMetadata(
                metadata.__class__,  # type: ignore
                KVSharingFastPrefillMetadata,
            ):
                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all fields in metadata cls. dataclasses.fields
                    # requires metadata's class to be a dataclass.
                    for _field in fields(metadata.__class__):
                        setattr(self, _field.name, getattr(metadata, _field.name))

                    self.logits_indices_padded = (
                        common_attn_metadata.logits_indices_padded
                    )
                    self.num_logits_indices = common_attn_metadata.num_logits_indices

            # The logits fields come from the caller's original
            # common_attn_metadata, not from new_common_attn_metadata.
            return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder,
    )

    return attn_backend
891
892
893
894


def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
    """
    Precompute the chunk-scheduling metadata needed by causal_conv1d.

    Each sequence of length L is cut into ceil(L / BLOCK_M) chunks. For every
    chunk, the returned tensors record which sequence it belongs to and its
    index within that sequence.

    Args:
        query_start_loc_p: Cumulative start offsets of the sequences; the
            sequence lengths are its consecutive differences.

    Returns:
        A tuple ``(nums_dict, batch_ptr, token_chunk_offset_ptr)``.
        ``nums_dict`` maps each BLOCK_M to a dict with:
          - "nums": number of chunks per sequence (CPU tensor);
          - "tot": total number of chunks (Python scalar);
          - "mlist": sequence index of each chunk (CPU tensor);
          - "mlist_len": ``len(mlist)``;
          - "offsetlist": chunk index within its sequence (int32 CPU tensor);
          - "batch_ptr" / "token_chunk_offset_ptr": int32 tensors on the
            input's device. Their first ``mlist_len`` entries are
            ``mlist`` / ``offsetlist``; when freshly allocated, the remaining
            entries are PAD_SLOT_ID.
        ``batch_ptr`` and ``token_chunk_offset_ptr`` are those same device
        tensors.
    """
    # Needed for causal_conv1d
    seqlens = query_start_loc_p.diff().to("cpu")
    device = query_start_loc_p.device

    nums_dict = {}  # type: ignore
    batch_ptr = None
    token_chunk_offset_ptr = None

    for BLOCK_M in [8]:  # cover all BLOCK_M values
        # ceil(seqlen / BLOCK_M) chunks per sequence.
        nums = -(-seqlens // BLOCK_M)
        nums_dict[BLOCK_M] = {}
        nums_dict[BLOCK_M]["nums"] = nums
        nums_dict[BLOCK_M]["tot"] = nums.sum().item()

        # Sequence index of every chunk, e.g. nums=[2, 3] -> [0, 0, 1, 1, 1].
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
        nums_dict[BLOCK_M]["mlist"] = mlist
        mlist_len = len(mlist)
        nums_dict[BLOCK_M]["mlist_len"] = mlist_len
        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2

        # Chunk index within its own sequence, e.g. nums=[2, 3] -> [0, 1, 0, 1, 2].
        # Iterate Python ints (via tolist) rather than 0-dim tensors.
        offsetlist = torch.tensor(
            [offset for num in nums.tolist() for offset in range(num)],
            dtype=torch.int32,
        )
        nums_dict[BLOCK_M]["offsetlist"] = offsetlist

        if batch_ptr is None:
            # First BLOCK_M: allocate device buffers pre-filled with PAD_SLOT_ID.
            batch_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
            token_chunk_offset_ptr = torch.full(
                (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
            )
        else:
            # Only reachable with more than one BLOCK_M: reuse the buffers,
            # growing (and re-padding) them when they are too small.
            # NOTE(review): buffers that are already large enough are not
            # re-padded, and all BLOCK_M entries alias the same two tensors.
            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
                token_chunk_offset_ptr.resize_(  # type: ignore
                    MAX_NUM_PROGRAMS
                ).fill_(PAD_SLOT_ID)

        batch_ptr[0:mlist_len].copy_(mlist)
        token_chunk_offset_ptr[0:mlist_len].copy_(offsetlist)  # type: ignore
        nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr
        nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

    return nums_dict, batch_ptr, token_chunk_offset_ptr
939
940
941
942


def get_dcp_local_seq_lens(
    seq_lens: torch.Tensor,
943
    dcp_size: int = 1,
944
    dcp_rank: int | None = None,
945
    cp_kv_cache_interleave_size: int = 1,
946
947
948
949
950
951
952
953
) -> torch.Tensor:
    """While using dcp, kv_cache size stored on each rank may be different,
    use this function to calculate split decode seq_lens of each dcp rank.
    Only consider dcp now, we can extend the case of cp based on this.
    """
    num_requests = seq_lens.size(0)
    if dcp_rank is None:
        rank_offsets = (
954
            torch.arange(dcp_size, dtype=torch.int32, device=seq_lens.device)
955
956
957
958
            .unsqueeze(0)
            .repeat(num_requests, 1)
        )
    else:
959
960
961
        rank_offsets = torch.tensor(
            [[dcp_rank]], dtype=torch.int32, device=seq_lens.device
        )
962
963
964
965
966
    seq_lens_tiled = (
        seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
    )
    base = (
        seq_lens_tiled
967
968
969
        // cp_kv_cache_interleave_size
        // dcp_size
        * cp_kv_cache_interleave_size
970
    )
971
    remainder = seq_lens_tiled - base * dcp_size
972
    remainder = torch.clip(
973
        remainder - rank_offsets * cp_kv_cache_interleave_size,
974
        0,
975
        cp_kv_cache_interleave_size,
976
977
978
    )
    dcp_local_seq_lens = base + remainder
    return dcp_local_seq_lens.squeeze(1)