"docs/vscode:/vscode.git/clone" did not exist on "3462c1c522d214755f1dfce3d645ab5afe7f00ae"
input_batch.py 18.1 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import numpy as np
import torch

8
from vllm.triton_utils import tl, triton
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from vllm.utils import random_uuid


class InputBuffers:
    """Preallocated device buffers sized for the largest supported batch.

    Batches take prefix slices of these tensors (e.g. ``InputBatch.make_dummy``
    slices ``input_ids[:num_tokens]`` and ``seq_lens[:num_reqs]``) instead of
    allocating new tensors each time.

    Args:
        max_num_reqs: Capacity in requests.
        max_num_tokens: Capacity in tokens.
        device: Device on which every buffer is allocated.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.device = device

        # [max_num_tokens] token ids, int32.
        self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
        # [max_num_tokens] token positions, int64.
        self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
        # [max_num_reqs + 1] prefix sums of per-request query lengths.
        self.query_start_loc = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=device
        )
        # [max_num_reqs] per-request sequence lengths.
        self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
        # DCP: per-request local seq_lens buffer
        self.dcp_local_seq_lens = torch.zeros(
            max_num_reqs, dtype=torch.int32, device=device
        )


@dataclass
class InputBatch:
    """Batch-level input tensors and metadata for a set of scheduled requests.

    Device tensors are paired with host-side ``np.ndarray`` copies where both
    are used. ``make_dummy`` builds a placeholder instance whose tensors are
    slices of an ``InputBuffers``.
    """

    # batch_idx -> req_id
    req_ids: list[str]
    num_reqs: int
    num_reqs_after_padding: int

    # batch_idx -> req_state_idx
    idx_mapping: torch.Tensor
    idx_mapping_np: np.ndarray

    # Identical to idx_mapping except for spec decoding.
    expanded_idx_mapping: torch.Tensor

    # [total_num_logits] position within request for each logit
    expanded_local_pos: torch.Tensor

    # [num_reqs]
    # batch_idx -> num_scheduled_tokens
    num_scheduled_tokens: np.ndarray
    # sum(num_scheduled_tokens)
    num_tokens: int
    num_tokens_after_padding: int
    num_draft_tokens: int

    # [num_reqs + 1]
    query_start_loc: torch.Tensor
    query_start_loc_np: np.ndarray
    # [num_reqs]
    seq_lens: torch.Tensor

    # [num_reqs] CPU upper bound on seq_lens (see CommonAttentionMetadata).
    seq_lens_cpu_upper_bound: torch.Tensor

    # [num_reqs]
    dcp_local_seq_lens: torch.Tensor | None

    # [num_tokens_after_padding]
    input_ids: torch.Tensor
    # [num_tokens_after_padding]
    positions: torch.Tensor

    # [total_num_logits] token index of each logit to compute.
    logits_indices: torch.Tensor

    # [num_reqs + 1] prefix sums of the per-request logit counts.
    cu_num_logits: torch.Tensor
    cu_num_logits_np: np.ndarray

    # Whether any requests in batch use structured output.
    has_structured_output_reqs: bool

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        num_tokens: int,
        input_buffers: InputBuffers,
    ) -> "InputBatch":
        """Build a placeholder batch of ``num_reqs`` requests over ``num_tokens``.

        Tokens are split evenly across requests and the last request takes the
        remainder. Every request's seq_len equals its query length, which is a
        fresh-prefill shape. ``input_ids`` and ``positions`` are zeroed, and
        each request yields exactly one logit at the last token of its query.

        The device tensors are slices of ``input_buffers``. Buffer entries past
        the active requests are padded for full CUDA graph mode: ``seq_lens``
        with 0 and ``query_start_loc`` with ``num_tokens``.

        Args:
            num_reqs: Number of dummy requests; must satisfy
                ``0 < num_reqs <= num_tokens``.
            num_tokens: Total number of scheduled tokens.
            input_buffers: Buffers that back the returned tensors. They are
                modified in place.
        """
        assert 0 < num_reqs <= num_tokens
        device = input_buffers.device

        req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
        # One logit per request, so no expansion is needed.
        expanded_idx_mapping = idx_mapping
        expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)

        # Even split; the last request absorbs the remainder.
        num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        assert int(num_scheduled_tokens.sum()) == num_tokens

        # seq_len equals to query_len
        input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
        input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
        # Pad for full CUDA graph mode.
        input_buffers.seq_lens[num_reqs:] = 0
        seq_lens = input_buffers.seq_lens[:num_reqs]

        query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
        # Because seq_len == query_len here, the cumsum of seq_lens also gives
        # the query start offsets.
        input_buffers.query_start_loc[:1] = 0
        torch.cumsum(
            seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
        )
        # Pad for full CUDA graph mode.
        input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
        query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

        input_ids = input_buffers.input_ids[:num_tokens].zero_()
        positions = input_buffers.positions[:num_tokens].zero_()

        # One logit per request, taken at the last token of its query.
        logits_indices = query_start_loc[1:] - 1
        cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
        cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
        # Dummy: seq_len == query_len (fresh-prefill shape).
        seq_lens_cpu_upper_bound = torch.from_numpy(num_scheduled_tokens.copy())

        return cls(
            req_ids=req_ids,
            num_reqs=num_reqs,
            num_reqs_after_padding=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            expanded_idx_mapping=expanded_idx_mapping,
            expanded_local_pos=expanded_local_pos,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens,
            num_draft_tokens=0,
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
            dcp_local_seq_lens=None,
            input_ids=input_ids,
            positions=positions,
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
            has_structured_output_reqs=False,
        )


# Gathers prompt tokens for requests that are still prefilling. Launched with
# one program per batch entry. For a request with num_computed < prefill_len:
# - copies all_token_ids[req_state_idx, num_computed : num_computed + query_len]
#   into input_ids[query_start : query_end];
# - if the prefill continues past this chunk, stores the token at
#   num_computed + query_len into next_prefill_tokens[req_state_idx].
# Programs for requests that finished prefilling exit immediately.
@triton.jit
def _prepare_prefill_inputs_kernel(
    input_ids_ptr,
    next_prefill_tokens_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    if num_computed >= prefill_len:
        # Not prefill.
        return

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    # Row of this request in all_token_ids.
    request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        tokens = tl.load(request_ptr + num_computed + block, mask=mask)
        tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)

    # Chunked prefill: remember the first token of the next chunk.
    next_pos = num_computed + query_len
    if next_pos < prefill_len:
        next_token = tl.load(request_ptr + next_pos)
        tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)


def prepare_prefill_inputs(
    input_ids: torch.Tensor,
    next_prefill_tokens: torch.Tensor,
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    all_token_ids: torch.Tensor,
    prefill_len: torch.Tensor,
    num_computed_tokens: torch.Tensor,
) -> None:
    """Copy prompt tokens of still-prefilling requests into ``input_ids``.

    Launches ``_prepare_prefill_inputs_kernel`` with one program per batch
    entry (``idx_mapping.shape[0]``). Writes ``input_ids`` and, for requests
    whose prefill spans further chunks, ``next_prefill_tokens`` in place.
    """
    num_reqs = idx_mapping.shape[0]
    _prepare_prefill_inputs_kernel[(num_reqs,)](
        input_ids,
        next_prefill_tokens,
        idx_mapping,
        query_start_loc,
        all_token_ids,
        all_token_ids.stride(0),
        prefill_len,
        num_computed_tokens,
        BLOCK_SIZE=1024,
    )


# Launched with num_reqs + 1 programs.
# - Program i < num_reqs: seq_lens[i] = num_computed_tokens + query_len, and
#   pos[query_start + j] = num_computed_tokens + j for j in [0, query_len).
# - The last program zeroes seq_lens[num_reqs : max_num_reqs].
@triton.jit
def _prepare_pos_seq_lens_kernel(
    pos_ptr,
    seq_lens_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    num_computed_tokens_ptr,
    max_num_reqs,
    BLOCK_SIZE: tl.constexpr,
):
    req_id = tl.program_id(0)
    num_reqs = tl.num_programs(0) - 1
    if req_id == num_reqs:
        # Pad unused seq_lens as 0 for full CUDA graphs.
        for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
            block = i + tl.arange(0, BLOCK_SIZE)
            mask = block < max_num_reqs
            tl.store(seq_lens_ptr + block, 0, mask=mask)
        return

    req_state_idx = tl.load(idx_mapping_ptr + req_id)
    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)

    start = tl.load(query_start_loc_ptr + req_id)
    end = tl.load(query_start_loc_ptr + req_id + 1)
    query_len = end - start

    seq_len = num_computed_tokens + query_len
    tl.store(seq_lens_ptr + req_id, seq_len)

    # Positions continue from the number of already-computed tokens.
    for i in tl.range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        pos = num_computed_tokens + block
        tl.store(pos_ptr + start + block, pos, mask=mask)


def prepare_pos_seq_lens(
    idx_mapping: torch.Tensor,
    query_start_loc: torch.Tensor,
    num_computed_tokens: torch.Tensor,
    pos: torch.Tensor,
    seq_lens: torch.Tensor,
) -> None:
    """Fill per-token positions and per-request seq_lens in place.

    The kernel runs one program per batch entry plus one trailing program.
    The trailing program zeroes the ``seq_lens`` entries past the active
    requests so that full CUDA graphs see zero-length padding.
    """
    # Capacity of the seq_lens buffer, which bounds the padding range.
    seq_lens_capacity = seq_lens.shape[0]
    # NOTE(woosuk): the extra "+ 1" program is the seq_lens padding program.
    launch_grid = (idx_mapping.shape[0] + 1,)
    _prepare_pos_seq_lens_kernel[launch_grid](
        pos,
        seq_lens,
        idx_mapping,
        query_start_loc,
        num_computed_tokens,
        seq_lens_capacity,
        BLOCK_SIZE=1024,
    )


# One program per batch entry. For each request:
# - writes logits_indices[cu_num_logits[b] + j] = query_end - num_logits + j
#   for j in [0, num_logits);
# - if seq_len > prefill_len (i.e. past prefill), writes the last sampled token
#   at input_ids[query_end - num_logits] and the num_logits - 1 draft tokens
#   into the trailing slots input_ids[query_end - num_draft_tokens : query_end].
# BLOCK_SIZE must be >= num_logits for every request (the host side picks
# next_power_of_2(num_speculative_steps + 1)).
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
    input_ids_ptr,
    idx_mapping_ptr,
    last_sampled_tokens_ptr,
    query_start_loc_ptr,
    seq_lens_ptr,
    prefill_len_ptr,
    draft_tokens_ptr,
    draft_tokens_stride,
    cu_num_logits_ptr,
    logits_indices_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    # Get the number of logits and draft tokens.
    cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
    cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
    num_logits = cu_num_logits_end - cu_num_logits_start
    num_draft_tokens = num_logits - 1

    # Compute the logits indices: the last num_logits tokens of the query.
    block = tl.arange(0, BLOCK_SIZE)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    logits_start = query_end - num_logits
    tl.store(
        logits_indices_ptr + cu_num_logits_start + block,
        logits_start + block,
        mask=block < num_logits,
    )

    seq_len = tl.load(seq_lens_ptr + batch_idx)
    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
    if seq_len <= prefill_len:
        # Handling prefill tokens. No sampled or draft tokens.
        return

    # Write the last sampled token ID to input_ids.
    last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
    tl.store(input_ids_ptr + query_end - num_logits, last_token_id)

    # Write the draft tokens (if any) to input_ids.
    if num_draft_tokens > 0:
        mask = block < num_draft_tokens
        draft_tokens = tl.load(
            draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
            mask=mask,
        )
        tl.store(
            input_ids_ptr + query_end - num_draft_tokens + block,
            draft_tokens,
            mask=mask,
        )


def combine_sampled_and_draft_tokens(
    input_ids: torch.Tensor,
    idx_mapping: torch.Tensor,
    last_sampled_tokens: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    prefill_len: torch.Tensor,
    draft_tokens: torch.Tensor,
    cu_num_logits: torch.Tensor,
    num_logits: int,
) -> torch.Tensor:
    """Write sampled and draft tokens into ``input_ids`` and compute logits indices.

    For requests past prefill, the last sampled token and any draft tokens
    are written in place into the tail of each request's query in
    ``input_ids``.

    Args:
        num_logits: Total number of logits; the length of the returned tensor.

    Returns:
        An int64 tensor of length ``num_logits`` holding the token index of
        each logit, laid out per request according to ``cu_num_logits``.
    """
    # use idx_mapping.shape[0] for actual request count
    num_reqs = idx_mapping.shape[0]
    num_speculative_steps = draft_tokens.shape[-1]

    logits_indices = torch.empty(
        num_logits,
        dtype=torch.int64,
        device=input_ids.device,
    )
    _combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
        input_ids,
        idx_mapping,
        last_sampled_tokens,
        query_start_loc,
        seq_lens,
        prefill_len,
        draft_tokens,
        draft_tokens.stride(0),
        cu_num_logits,
        logits_indices,
        # NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
        # in addition to all draft tokens.
        BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
    )
    return logits_indices


# One program per batch entry. Requests still in chunked prefill
# (seq_len < prefill_len) get num_sampled = num_rejected = 0; otherwise
# num_rejected = num_logits - num_sampled. num_sampled is rewritten in place.
@triton.jit
def _get_num_sampled_and_rejected_kernel(
    num_sampled_ptr,
    num_rejected_ptr,
    seq_lens_ptr,
    cu_num_logits_ptr,
    idx_mapping_ptr,
    prefill_len_ptr,
):
    row = tl.program_id(0)
    state_idx = tl.load(idx_mapping_ptr + row)

    still_prefilling = tl.load(seq_lens_ptr + row) < tl.load(
        prefill_len_ptr + state_idx
    )

    sampled = tl.where(still_prefilling, 0, tl.load(num_sampled_ptr + row))
    tl.store(num_sampled_ptr + row, sampled)

    logits_lo = tl.load(cu_num_logits_ptr + row)
    logits_hi = tl.load(cu_num_logits_ptr + row + 1)
    rejected = tl.where(still_prefilling, 0, (logits_hi - logits_lo) - sampled)
    tl.store(num_rejected_ptr + row, rejected)


def get_num_sampled_and_rejected(
    num_sampled: torch.Tensor,
    seq_lens: torch.Tensor,
    cu_num_logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    prefill_len: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mask out chunked-prefill requests and compute per-request rejections.

    ``num_sampled`` is modified in place and returned as the first element.
    The second element is a new tensor, with the same shape and dtype as
    ``num_sampled``, holding ``num_logits - num_sampled``. Both entries are
    0 for requests still in chunked prefill.
    """
    rejected_out = torch.empty_like(num_sampled)
    launch_grid = (idx_mapping.shape[0],)
    _get_num_sampled_and_rejected_kernel[launch_grid](
        num_sampled,
        rejected_out,
        seq_lens,
        cu_num_logits,
        idx_mapping,
        prefill_len,
    )
    return num_sampled, rejected_out


# One program per batch entry (req_id). For each request:
# - if num_sampled > 0: records the last sampled token in last_sampled_tokens
#   and advances total_len by num_sampled;
# - appends the sampled tokens to all_token_ids starting at the old total_len,
#   and (if output_bin_counts is given) increments their per-request counts;
# - advances num_computed_tokens by query_len - num_rejected.
@triton.jit
def _post_update_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    last_sampled_tokens_ptr,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    sampled_tokens_ptr,
    sampled_tokens_stride,
    num_sampled_ptr,
    num_rejected_ptr,
    query_start_loc_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    total_len_ptr,
):
    req_id = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + req_id)

    # Length before this update; used as the append offset below.
    total_len = tl.load(total_len_ptr + req_state_idx)
    num_sampled = tl.load(num_sampled_ptr + req_id)
    if num_sampled > 0:
        token_id = tl.load(
            sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
        )
        tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
        tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)

    for i in range(num_sampled):
        token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
        tl.store(
            all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
            token_id,
        )

        # NOTE(review): plain load/store, not atomic; assumes each
        # req_state_idx appears at most once in idx_mapping. Confirm.
        if output_bin_counts_ptr is not None:
            token_ptr = (
                output_bin_counts_ptr
                + req_state_idx * output_bin_counts_stride
                + token_id
            )
            count = tl.load(token_ptr)
            tl.store(token_ptr, count + 1)

    query_start = tl.load(query_start_loc_ptr + req_id)
    query_end = tl.load(query_start_loc_ptr + req_id + 1)
    query_len = query_end - query_start
    num_rejected = tl.load(num_rejected_ptr + req_id)

    # Rejected tokens in this query do not count as computed.
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    num_computed += query_len - num_rejected
    tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)


def post_update(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [max_num_reqs]
    last_sampled_tokens: torch.Tensor,
    # [max_num_reqs, vocab_size]
    output_bin_counts: torch.Tensor | None,
    # [num_reqs, num_speculative_steps + 1]
    sampled_tokens: torch.Tensor,
    # [num_reqs]
    num_sampled: torch.Tensor,
    # [num_reqs]
    num_rejected: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
    # [max_num_reqs, max_model_len]
    all_token_ids: torch.Tensor,
    # [max_num_reqs]
    total_len: torch.Tensor,
) -> None:
    """Update per-request state in place after a step.

    Launches ``_post_update_kernel`` with one program per batch entry. It
    updates ``num_computed_tokens``, ``last_sampled_tokens``, ``total_len``,
    ``all_token_ids`` and, when provided, ``output_bin_counts``. Pass
    ``output_bin_counts=None`` to skip the bin-count update.
    """
    num_reqs = idx_mapping.shape[0]
    _post_update_kernel[(num_reqs,)](
        idx_mapping,
        num_computed_tokens,
        last_sampled_tokens,
        output_bin_counts,
        output_bin_counts.stride(0) if output_bin_counts is not None else 0,
        sampled_tokens,
        sampled_tokens.stride(0),
        num_sampled,
        num_rejected,
        query_start_loc,
        all_token_ids,
        all_token_ids.stride(0),
        total_len,
        num_warps=1,
    )


# One program per batch entry: num_computed_tokens[idx_mapping[b]] grows by
# that entry's query length (query_start_loc[b + 1] - query_start_loc[b]).
@triton.jit
def _post_update_pool_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
):
    b = tl.program_id(0)
    q_len = tl.load(query_start_loc_ptr + b + 1) - tl.load(query_start_loc_ptr + b)

    counter_ptr = num_computed_tokens_ptr + tl.load(idx_mapping_ptr + b)
    tl.store(counter_ptr, tl.load(counter_ptr) + q_len)


def post_update_pool(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
) -> None:
    """Advance ``num_computed_tokens`` in place by each request's query length.

    Unlike ``post_update``, no sampled tokens are recorded here.
    """
    launch_grid = (idx_mapping.shape[0],)
    _post_update_pool_kernel[launch_grid](
        idx_mapping,
        num_computed_tokens,
        query_start_loc,
    )


# One program per request. Over the request's logit range
# [cu_num_logits[r], cu_num_logits[r + 1]) it writes idx_mapping[r] into
# expanded_idx_mapping and 0, 1, 2, ... into expanded_local_pos. Only the first
# BLOCK_SIZE entries of each range are written.
@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    expanded_local_pos_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx

    block = tl.arange(0, BLOCK_SIZE)
    mask = block < num_tokens
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
    tl.store(expanded_local_pos_ptr + start_idx + block, block, mask=mask)


def expand_idx_mapping(
    idx_mapping: torch.Tensor,
    total_num_logits: int,
    cu_num_logits: torch.Tensor,
    max_expand_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Expand a per-request index mapping to one entry per logit.

    For request ``i`` with ``n_i = cu_num_logits[i + 1] - cu_num_logits[i]``
    logits, the slice ``cu_num_logits[i] : cu_num_logits[i + 1]`` receives
    ``idx_mapping[i]`` in ``expanded_idx_mapping`` and ``0 .. n_i - 1`` in
    ``expanded_local_pos``.

    Args:
        idx_mapping: Per-request state indices, one per batch entry.
        total_num_logits: Length of both outputs.
        cu_num_logits: Prefix sums of per-request logit counts (num_reqs + 1).
        max_expand_len: Upper bound on ``n_i``. It sets the kernel block size;
            if any ``n_i`` exceeds ``next_power_of_2(max_expand_len)``, the
            excess entries are never written. The outputs come from
            ``empty``, so such entries would hold garbage.

    Returns:
        ``(expanded_idx_mapping, expanded_local_pos)``. The first has
        ``idx_mapping``'s dtype; the second is int32.
    """
    num_reqs = idx_mapping.shape[0]
    expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
    expanded_local_pos = torch.empty(
        total_num_logits, dtype=torch.int32, device=idx_mapping.device
    )
    _expand_idx_mapping_kernel[(num_reqs,)](
        idx_mapping,
        expanded_idx_mapping,
        expanded_local_pos,
        cu_num_logits,
        BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
    )
    return expanded_idx_mapping, expanded_local_pos