# tpu_model_runner.py — model runner for the vLLM TPU (PyTorch/XLA) backend.
import time
from typing import List, Mapping, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SamplerOutput, SequenceGroupMetadata,
                           SequenceOutput)
from vllm.utils import make_tensor_with_pad

logger = init_logger(__name__)

# Slot index written into the slot mapping for padding positions.
_PAD_SLOT_ID = -1  # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.

# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
# When False, requests with top_p != 1 are rejected in _prepare_sample and
# _apply_top_p is never called.
_ENABLE_TOP_P = False

# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
# Prefill steps always draw this many samples per prompt; requests whose
# best_of exceeds it are rejected.
_MAX_NUM_SAMPLES = 128


class TPUModelRunner:
    """Prepares inputs, runs the model and samples tokens on TPU.

    Prompts are executed one sequence group at a time with shape
    ``[1, _get_padded_prefill_len(prompt_len)]``. Decode steps are batched and
    padded to ``_get_padded_batch_size(num_seqs)`` rows. Padding keeps the
    set of distinct input shapes small. The ``torch.compile``-d model
    (openxla backend) then only has to be compiled for the shapes that
    ``warmup_model`` visits.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig] = None,
        is_driver_worker: bool = False,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        # NOTE(review): floor division — if max_model_len is not a multiple
        # of block_size, a sequence reaching the final partial block needs
        # one more column than this. Left unchanged because the Pallas
        # attention kernel may constrain this width; confirm before changing.
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        # Host-side staging buffer for decode block tables, reused across
        # steps. _prepare_decode slices it with the *padded* batch size.
        # NumPy clamps out-of-range slices, so the buffer must have one row
        # per padded slot. Sizing it by max_num_seqs alone yields too few
        # rows whenever max_num_seqs is not a padded size (e.g. 20 -> 32).
        self.block_tables = np.zeros(
            (_get_padded_batch_size(self.scheduler_config.max_num_seqs),
             self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            False,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

    def load_model(self) -> None:
        """Load the model weights and compile the forward+sampling wrapper."""
        self.device = self.device_config.device

        model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=self.parallel_config,
            cache_config=self.cache_config,
            scheduler_config=self.scheduler_config,
            multimodal_config=self.multimodal_config,
            lora_config=None,
        )
        xm.wait_device_ops()

        model = ModelWrapper(model)
        self.model = torch.compile(model, backend="openxla", fullgraph=True)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        """Run the compiled model once on zero-filled inputs of one shape.

        Used by ``warmup_model`` to trigger compilation for that shape.

        Args:
            batch_size: Number of rows in the dummy batch.
            seq_len: Tokens per row. For prompts it is rounded up to a
                multiple of 16; for decode it must be 1.
            kv_caches: KV caches passed through to the model.
            is_prompt: Whether to build prefill or decode inputs.
        """
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
        else:
            assert seq_len == 1

        token_ids = torch.zeros((batch_size, seq_len),
                                dtype=torch.int32,
                                device=self.device)
        position_ids = torch.zeros((batch_size, seq_len),
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros((batch_size, seq_len),
                                   dtype=torch.int64,
                                   device=self.device)
        input_lens = torch.ones((batch_size, ),
                                dtype=torch.int32,
                                device=self.device)

        if is_prompt:
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
        else:
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

        # Dummy run. num_samples must match _execute_model so the compiled
        # graph is reused at runtime.
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        self.model(token_ids, position_ids, kv_caches, attn_metadata,
                   input_lens, None, t, p, num_samples)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        """Pre-compile the model for the input shapes used at runtime.

        Prefill: batch size 1, sequence lengths 16, 32, 64, ... (the values
        returned by ``_get_padded_prefill_len``). It stops once
        ``max_model_len`` or ``max_num_batched_tokens`` is reached.

        Decode: sequence length 1, batch sizes 8, 16, 32, 48, ... (the values
        returned by ``_get_padded_batch_size``). It stops once
        ``max_num_seqs`` is reached.
        """
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info("Compilation for prefill done in %.2f s.", end - start)

        # Decode
        start = time.time()
        seq_len = 1
        # Start from the smallest padded decode batch size.
        # _get_padded_batch_size never returns anything smaller, so compiling
        # 1/2/4 only wastes time. Those sizes would also violate the GMM
        # kernel's "num_tokens * topk is a multiple of 16" requirement.
        batch_size = _get_padded_batch_size(1)
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            # Mirrors _get_padded_batch_size: 8, 16, then multiples of 16.
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info("Compilation for decode done in %.2f s.", end - start)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
               Mapping[str, BatchedTensors]]:
        """Build padded 2D prefill inputs.

        Each sequence group must be a prompt with exactly one sequence.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, prompt_lens,
            multi_modal_kwargs)``. The token, position and slot-mapping
            tensors have shape ``[num_prompts, padded_len]``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            # Map each token position to its physical KV-cache slot.
            slot_mapping.append([])
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        num_prefill_tokens = sum(prompt_lens)

        # Add paddings to make the shape [batch_size, max_prompt_len] where
        # max_prompt_len is smallest power of 2 that is greater than or equal
        # to the maximum prompt length.
        # We need the 2D input shape because the Pallas FlashAttention kernel
        # does not support packed 1D inputs.
        # We pad the seq_len to powers of 2 to reduce the compilation overhead.
        max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=torch.int32,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=torch.int32,
                                               device=self.device)
        slot_mapping = make_tensor_with_pad(slot_mapping,
                                            max_prompt_len,
                                            pad=_PAD_SLOT_ID,
                                            dtype=torch.int64,
                                            device=self.device)
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device=self.device)
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
               Mapping[str, BatchedTensors]]:
        """Build decode inputs: one token per sequence.

        The batch is padded to ``_get_padded_batch_size`` rows. Padding rows
        use token 0, position 0, slot ``_PAD_SLOT_ID`` and context length 0.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, input_lens,
            multi_modal_kwargs)``. The token and position tensors have
            shape ``[padded_batch_size, 1]``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device=self.device)
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device=self.device)
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device=self.device)
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, attn_metadata, input_lens,
                multi_modal_kwargs)

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Collect per-row sampling parameters.

        Args:
            seq_group_metadata_list: The scheduled sequence groups.
            padded_batch_size: Length to pad the temperature and top-p
                tensors to. Padding rows use 1.0.

        Returns:
            ``(t, p, best_of)``. ``t`` and ``p`` are float32 tensors of
            length ``padded_batch_size``. ``best_of`` is an unpadded list
            with one entry per sequence.

        Raises:
            NotImplementedError: For top-p (while disabled), top-k, best_of
                above ``_MAX_NUM_SAMPLES``, beam search, logprobs or
                prompt_logprobs.
        """
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
            # low temperature. This is not accurate.
            t.append(sampling_params.temperature
                     if sampling_params.temperature >= 1e-5 else 1e-5)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            best_of += [best_of[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device=self.device)
        p = torch.tensor(p, dtype=torch.float32, device=self.device)
        return t, p, best_of

    def _execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> List[CompletionSequenceGroupOutput]:
        """Run one model step and convert sampled tokens to sampler outputs.

        All groups in the list must be prompts, or all must be decodes.
        Prompts emit ``best_of`` sampled tokens per sequence; decodes emit
        one token per sequence. Every token gets a log-probability of 0.0.
        """
        # Prepare inputs.
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        padded_batch_size = inputs[0].shape[0]
        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        # Execute the model.
        # inputs[2:] is (attn_metadata, input_lens, multi_modal_kwargs).
        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
                                    *inputs[2:], t, p, num_samples)
        # Retrieve the outputs to CPU.
        next_token_ids = next_token_ids.cpu().tolist()

        # NOTE(woosuk): Minimal code to construct the sampler outputs.
        # The TPU backend does not reuse the sampler, since the TPU backend
        # does not support the advanced sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        batch_idx = 0
        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []
            seq_ids = list(seq_group_metadata.seq_data.keys())
            if is_prompt:
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                for i in range(best_of[batch_idx]):
                    next_token_id = next_token_ids[batch_idx][i]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                batch_idx += 1
            else:
                for seq_id in seq_ids:
                    next_token_id = next_token_ids[batch_idx][0]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                    batch_idx += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return sampler_outputs

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        """Execute one scheduler step and return a single ``SamplerOutput``.

        Raises:
            ValueError: If ``num_steps > 1``; multi-step execution is not
                supported.
        """
        if num_steps > 1:
            raise ValueError(
                "TPUModelRunner does not support multi-step execution.")

        assert seq_group_metadata_list is not None
        assert len(seq_group_metadata_list) > 0
        if seq_group_metadata_list[0].is_prompt:
            # NOTE(woosuk): To reduce the compilation time, we only compile the
            # prefill inputs with batch size 1. Because the scheduler is not
            # aware of this limitation, we need to handle batch size > 1
            # internally by calling the model multiple times and concatenating
            # the outputs.
            # FIXME(woosuk): This is a temporary hack to not change the existing
            # scheduler. We need to fix this in the future.
            sampler_outputs = []
            for seq_group_metadata in seq_group_metadata_list:
                sampler_outputs += self._execute_model([seq_group_metadata],
                                                       kv_caches)
        else:
            sampler_outputs = self._execute_model(seq_group_metadata_list,
                                                  kv_caches)
        return [SamplerOutput(sampler_outputs)]


class ModelWrapper(nn.Module):
    """Wraps a model so that a single forward call also samples tokens.

    ``TPUModelRunner.load_model`` compiles this whole forward+sampling path
    with ``torch.compile(..., backend="openxla", fullgraph=True)``.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model.eval()

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
            attn_metadata: The Pallas attention metadata. Note that its
                ``slot_mapping`` is replaced in place when KV caches are
                present.
            input_lens: The actual input lengths of shape [batch_size].
            multi_modal_kwargs: Keyword arguments from multi-modal data to
                pass to the model.
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size]. Only used when
                top-p sampling is enabled.
            num_samples: Number of tokens drawn (with replacement) per row.

        Returns:
            Sampled token IDs from ``torch.multinomial``, one row per
            sampled position and ``num_samples`` columns.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from: the last real token of
        # each row in the flattened [batch_size * seq_len] hidden states.
        base_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = base_indices + input_lens - 1

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
            **(multi_modal_kwargs or {}),
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        logits = logits / t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        next_token_ids = torch.multinomial(probs,
                                           num_samples,
                                           replacement=True)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
    # multiple of 16. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
601
602
603
604
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
605
606
607
608
609
610
611
612
613
614
615
616
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits