# tpu_model_runner.py
import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                           Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Slot id written into the slot mapping for padding tokens.
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000

# When False, any request with top_p != 1 is rejected in _prepare_sample.
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False

# Number of samples drawn per sequence in prefill graphs, and the upper bound
# accepted for `n`.
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


class ExecutionMode(enum.Enum):
    """Kind of forward pass the TPU runner compiles or executes.

    ``PREFILL`` runs prompts with no cached prefix, ``PREFIX_PREFILL`` runs
    prompts whose leading blocks are already cached (block tables and
    context lengths are passed to attention), and ``DECODE`` runs one token
    per sequence.
    """
    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        """Return True for both prefill variants, False for decode."""
        return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Inputs for one TPU model step.

    Produced by ``TPUModelRunner.prepare_model_input`` and serialized with
    ``as_broadcastable_tensor_dict`` / rebuilt with
    ``from_broadcasted_tensor_dict``. Field order matters: instances are
    also constructed positionally.
    """
    # Prefill: 1-D concatenation of per-prompt padded tokens/positions.
    # Decode: one row of length 1 per (padded) sequence.
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    # Prefill: unpadded new-token count per prompt. Decode: all ones.
    input_lens: torch.Tensor
    # Per-sequence temperature and top-p, padded with 1.0.
    t: torch.Tensor
    p: torch.Tensor
    # Samples drawn per sequence by the compiled model.
    num_samples: int
    # Requested `n` per sequence (not padded).
    n: List[int]
    # Sequence ids of each sequence group, in batch order.
    seq_groups: List[List[int]]
    # Multi-step bookkeeping.
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    # Excluded from as_broadcastable_tensor_dict.
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Serialize every field except ``async_callback`` into a dict.

        Attention metadata is flattened into the same dict by
        ``_add_attn_metadata_broadcastable_dict``.
        """
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "n": self.n,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Rebuild an instance from ``as_broadcastable_tensor_dict`` output.

        When ``attn_backend`` is given, the flattened attention-metadata
        entries are first folded back into an ``attn_metadata`` object.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
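    # TPU model runner: prepares prefill/decode inputs as CPU tensors and runs
    # the model compiled with torch.compile(backend="openxla").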

    def __init__(
        self,
        vllm_config: VllmConfig,
        is_driver_worker: bool = False,
    ):
        """Set up block-table staging buffers and the attention backend.

        Args:
            vllm_config: Engine configuration forwarded to ModelRunnerBase.
            is_driver_worker: Whether this runner belongs to the driver.
        """
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        # Round up. A sequence of max_model_len tokens occupies
        # ceil(max_model_len / block_size) blocks. With floor division its
        # block table would not fit in a row of self.block_tables when
        # max_model_len is not a multiple of block_size.
        self.max_num_blocks_per_seq = (
            (self.model_config.max_model_len + self.block_size - 1) //
            self.block_size)
        # Host-side block tables, one row per schedulable sequence; rows are
        # filled in _prepare_prompt / _prepare_decode.
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            False,
        )
        # Outputs of earlier multi-step iterations; popped by execute_model
        # on the last step.
        self.cached_step_outputs: List[torch.Tensor] = []

        # Warn when the int32 block tables exceed the assumed 512 KiB SMEM
        # budget.
        smem_size = 512 * 1024
        block_table_size = self.block_tables.nbytes
        if block_table_size >= smem_size:
            logger.warning(
                "The max_model_len (%d) is too large. This may degrade the "
                "performance due to the insufficient smem size. Consider "
                "setting --max-model-len to a smaller value.",
                self.model_config.max_model_len)

    def load_model(self) -> None:
        """Load the model and wrap it with ``torch.compile`` (openxla).

        The model is compiled with ``fullgraph=True`` and ``dynamic=False``.
        Concrete input shapes are compiled later by ``warmup_model``.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        # Block until weight loading has finished on the device.
        xm.wait_device_ops()
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        """Run the model once on placeholder inputs to compile that shape.

        Args:
            batch_size: Number of sequences in the dummy batch.
            seq_len: Tokens per sequence. Prefill modes round it up to a
                multiple of 16; decode requires exactly 1.
            kv_caches: KV cache tensors forwarded to the model unchanged.
            exec_mode: Graph variant to compile; an ``ExecutionMode`` member
                or its value.
        """
        exec_mode = ExecutionMode(exec_mode)
        if exec_mode.is_prefill():
            seq_len = (seq_len + 15) // 16 * 16
        else:
            assert seq_len == 1

        # Inputs common to every mode. Their values are placeholders; the run
        # exists to compile the graph for these shapes.
        token_ids = torch.zeros((batch_size, seq_len),
                                dtype=torch.int32,
                                device=self.device)
        position_ids = torch.zeros((batch_size, seq_len),
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros((batch_size, seq_len),
                                   dtype=torch.int64,
                                   device=self.device)
        input_lens = torch.ones((batch_size, ),
                                dtype=torch.int32,
                                device=self.device)

        if exec_mode == ExecutionMode.PREFILL:
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                block_tables=None,
                context_lens=None,
                effective_query_lens=None,
            )
        elif exec_mode == ExecutionMode.PREFIX_PREFILL:
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            block_tables = torch.tensor(self.block_tables[:batch_size],
                                        dtype=torch.int32,
                                        device=self.device)
            effective_query_lens = torch.ones_like(context_lens)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                block_tables=block_tables,
                context_lens=context_lens,
                effective_query_lens=effective_query_lens,
            )
        else:
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shapes. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs in the disk (VLLM_XLA_CACHE_PATH).
        if exec_mode.is_prefill():
            # Prefill: the sequence dimension varies.
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode: the batch dimension varies.
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)

        # Dummy run.
        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
                   num_samples, kv_caches)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        """Pre-compile the model for the input shapes used at runtime.

        Compilation happens in three phases:
        1. Prefill graphs for sequence lengths 16, 32, 64, ...
        2. The prefix-prefill variant of those graphs, only when prefix
           caching is enabled.
        3. Decode graphs for batch sizes 8, 16, 32, 48, ... up to
           ``max_num_seqs``.
        """
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.perf_counter()
        self._compile_prefill_shapes(kv_caches, ExecutionMode.PREFILL)
        end = time.perf_counter()
        logger.info("Compilation for prefill done in %.2f s.", end - start)

        # Prefix prefill
        if self.cache_config.enable_prefix_caching:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill...")
            start = time.perf_counter()
            self._compile_prefill_shapes(kv_caches,
                                         ExecutionMode.PREFIX_PREFILL)
            end = time.perf_counter()
            logger.info("Compilation for prefix prefill done in %.2f s.",
                        end - start)

        # Decode
        start = time.perf_counter()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.perf_counter()
        logger.info("Compilation for decode done in %.2f s.", end - start)

    def _compile_prefill_shapes(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        """Compile a prefill mode for batch size 1 and doubling lengths.

        Lengths start at 16 and double while they do not exceed
        ``max_model_len``. The loop stops after the first length whose token
        count reaches ``max_num_batched_tokens``.
        """
        batch_size = 1
        seq_len = 16
        while seq_len <= self.model_config.max_model_len:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=exec_mode)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
            num_tokens = batch_size * seq_len
            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                break
            seq_len *= 2

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build flattened CPU inputs for a batch of prompts.

        Each prompt (one sequence per group) is padded on its own, and all
        prompts are concatenated into 1-D token, position and slot tensors.
        When leading blocks are already computed (prefix caching), only the
        remaining tokens are fed. The full sequence length is then recorded
        as the context length, and the block table is copied into
        ``self.block_tables[batch_idx]``.

        Returns:
            (input_tokens, input_positions, attn_metadata, prompt_lens).
            prompt_lens holds the unpadded number of fed tokens per prompt.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        context_lens: List[int] = []
        slot_mapping: List[int] = []

        for batch_idx, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)

            # Tokens covered by already-computed blocks are not fed again.
            # A missing (None) list is treated as "no computed blocks".
            computed_block_nums = seq_group_metadata.computed_block_nums or ()
            num_computed_tokens = len(computed_block_nums) * self.block_size
            if num_computed_tokens > 0:
                prompt_tokens = prompt_tokens[num_computed_tokens:]
                context_lens.append(seq_len)
            else:
                context_lens.append(0)

            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(range(num_computed_tokens, seq_len))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            # Physical KV slot of each fed token: block * block_size + offset.
            slot_mapping.extend(
                block_table[pos // self.block_size] * self.block_size +
                pos % self.block_size
                for pos in range(num_computed_tokens, seq_len))
            if num_computed_tokens > 0:
                self.block_tables[batch_idx, :len(block_table)] = block_table

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        num_prefills = len(prompt_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:num_prefills],
                                    dtype=torch.int32,
                                    device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            block_tables=block_tables,
            context_lens=context_lens,
            effective_query_lens=prompt_lens,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build CPU inputs for one decode step, one row per sequence.

        Each sequence of each group contributes its last token. The position
        is ``len - 1``, and the sequence's KV slot is derived from its block
        table.

        Side effect: copies each sequence's block table into
        ``self.block_tables`` (row = running sequence index).

        Rows are padded to ``_get_padded_batch_size(num_seqs)``. Padding rows
        have token 0, position 0, context length 0 and ``_PAD_SLOT_ID``.

        Returns:
            (input_tokens, input_positions, attn_metadata, input_lens), with
            input_lens all ones.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_tokens.append([seq_data.get_last_token_id()])
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                # KV slot for the new token: block * block_size + offset.
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot_mapping.append([block_number * self.block_size +
                                     block_offset])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens += [[0]] * num_paddings
        input_positions += [[0]] * num_paddings
        slot_mapping += [[_PAD_SLOT_ID]] * num_paddings
        context_lens += [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Collect per-sequence sampling parameters.

        A group's parameters are repeated once per sequence in the group.

        Args:
            seq_group_metadata_list: Scheduled sequence groups.
            padded_batch_size: Length that ``t`` and ``p`` are padded to.

        Returns:
            (t, p, n): float32 CPU tensors of temperatures and top-p values,
            padded with 1.0, plus the unpadded list of requested ``n``.

        Raises:
            NotImplementedError: for parameters the TPU backend rejects:
                top_p != 1 (unless _ENABLE_TOP_P), top_k != -1,
                n > _MAX_NUM_SAMPLES, logprobs, or prompt_logprobs.
        """
        assert len(seq_group_metadata_list) > 0
        t: List[float] = []
        p: List[float] = []
        n: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.n > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [sampling_params.temperature] * num_seqs
            p += [sampling_params.top_p] * num_seqs
            n += [sampling_params.n] * num_seqs

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, n

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        """Turn scheduled sequence groups into a CPU-side ``ModelInputForTPU``.

        The first group's ``is_prompt`` flag selects the prefill or decode
        path for the whole batch. Prefill inputs use ``_MAX_NUM_SAMPLES``
        samples per sequence; decode inputs use 1.
        """
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs

        # NOTE(review): for prompts, input_tokens is flattened. This is then
        # the total padded token count, not the number of prompts. t/p end up
        # over-padded, but execute_model only reads the first
        # input_lens.shape[0] entries.
        padded_batch_size = input_tokens.shape[0]
        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                       padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(
            token_ids=input_tokens,
            position_ids=input_positions,
            attn_metadata=attn_metadata,
            input_lens=input_lens,
            t=t,
            p=p,
            num_samples=num_samples,
            n=n,
            seq_groups=seq_groups,
        )

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        """Rebuild model input from ``as_broadcastable_tensor_dict`` output.

        Attention metadata is reconstructed with this runner's backend.
        """
        return ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)

    @torch.no_grad()
581
582
583
    def execute_model(
        self,
        model_input: ModelInputForTPU,
584
        kv_caches: Optional[List[Any]],
585
586
587
588
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
612
613
                        is_last_step=False,
                        is_first_step_output=i == 0)
614
615
616
617
618
619
620
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
621
        if is_prompt:
622
            assert num_steps == 1
623
624
625
626
627
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
628
629
630
631
            orig_block_tables = model_input.attn_metadata.block_tables
            orig_context_lens = model_input.attn_metadata.context_lens
            orig_effective_query_lens = \
                model_input.attn_metadata.effective_query_lens
632
633
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
634
            next_token_ids = []
635
636
637
638
639
640
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

641
642
643
644
645
646
647
648
649
                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
650
651
652
653
654
655
656
657
658
659
660
                if orig_context_lens[i].item() > 0:
                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
                        self.device)
                    attn_metadata.block_tables = orig_block_tables[
                        i].unsqueeze(0).to(self.device)
                    attn_metadata.effective_query_lens = \
                        orig_effective_query_lens[i:i + 1].to(self.device)
                else:
                    attn_metadata.context_lens = None
                    attn_metadata.block_tables = None
                    attn_metadata.effective_query_lens = None
661
662
663
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
664
665
                output_token_ids = self.model(token_ids, position_ids,
                                              attn_metadata, input_lens, t, p,
666
                                              model_input.num_samples,
667
                                              kv_caches)
668
                next_token_ids.append(output_token_ids[0])
669
                start_idx = end_idx
670

671
672
            if model_input.async_callback is not None:
                model_input.async_callback()
673
            # Retrieve the outputs to CPU.
674
675
676
677
678
679
680
681
682
683
684
685
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
686
687
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
688
                seq_outputs = []
689
                for j in range(model_input.n[i]):
690
                    next_token_id = next_token_ids[i][j]
691
692
693
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
712
713
                output_token_ids = self.model(token_ids, position_ids,
                                              attn_metadata, input_lens, t, p,
714
                                              model_input.num_samples,
715
                                              kv_caches)
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]
748
749


750
class ModelWrapper(nn.Module):
    """Runs the wrapped model's forward pass and samples next tokens inline.

    Sampling (greedy / temperature / optional top-p / multinomial) is done
    here directly instead of through vLLM's generic sampler.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata. When the KV caches
                are non-empty, ``attn_metadata.slot_mapping`` is replaced in
                place with flattened per-head slot indices (see
                ``_flatten_slot_mapping``).
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size]. Zero means
                greedy decoding for that row.
            p: The top-p probability of shape [batch_size]. Only used when
                ``_ENABLE_TOP_P`` is set.
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. ``kv_caches[0][0]`` is always
                read, so this must not be None; when that tensor is empty
                (e.g. placeholders during memory profiling at initialization)
                the slot-mapping rewrite is skipped.

        Returns:
            Sampled token IDs of shape [batch_size] when ``num_samples == 1``,
            otherwise [batch_size, num_samples].
        """
        batch_size, seq_len = token_ids.shape
        # Index of each sequence's last real (non-padded) token within the
        # hidden states once they are flattened to [batch_size * seq_len].
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization (empty KV caches).
        if kv_caches[0][0].numel() > 0:
            # The KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size].
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            attn_metadata.slot_mapping = _flatten_slot_mapping(
                attn_metadata.slot_mapping, num_kv_heads, num_blocks,
                block_size)

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
        return _sample_token_ids(logits, t, p, num_samples, _ENABLE_TOP_P)


def _flatten_slot_mapping(slot_mapping: torch.Tensor, num_kv_heads: int,
                          num_blocks: int, block_size: int) -> torch.Tensor:
    """Expand per-token slot indices to indices into a flattened KV cache.

    index_copy_(slot_mapping) only works when the inserted dimension is 0,
    but the Pallas KV cache is [num_kv_heads, num_blocks, block_size,
    head_size]. After flattening its first three dimensions, head ``h`` of
    slot ``s`` lives at ``s + h * num_blocks * block_size``.

    Returns a 1-D tensor of length ``slot_mapping.numel() * num_kv_heads``,
    ordered token-major (all heads of token 0, then all heads of token 1...).
    Out-of-range padding slots stay out of range after the offset is added.
    """
    slots = slot_mapping.flatten()
    head_offsets = torch.arange(0,
                                num_kv_heads,
                                device=slots.device,
                                dtype=slots.dtype) * (block_size * num_blocks)
    # [num_tokens, 1] + [1, num_kv_heads] -> [num_tokens, num_kv_heads].
    return (slots.unsqueeze(1) + head_offsets.unsqueeze(0)).flatten()


def _sample_token_ids(logits: torch.Tensor,
                      t: torch.Tensor,
                      p: torch.Tensor,
                      num_samples: int,
                      enable_top_p: bool = False) -> torch.Tensor:
    """Sample ``num_samples`` token IDs per row of ``logits``.

    Rows with ``t == 0`` use argmax (greedy); the others sample from
    ``softmax(logits / t)``, optionally restricted by top-p.

    Args:
        logits: Logits of shape [batch_size, vocab_size].
        t: Temperatures of shape [batch_size].
        p: Top-p thresholds of shape [batch_size].
        num_samples: Number of samples to draw per row.
        enable_top_p: Whether to apply ``_apply_top_p`` before sampling.

    Returns:
        Token IDs of shape [batch_size] if ``num_samples == 1``, otherwise
        [batch_size, num_samples].
    """
    # Argmax sampling, replicated across samples: [batch_size, num_samples].
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

    # Zero temperature means greedy decoding. Avoid division by zero.
    nonzero_t = torch.where(t != 0, t, 1.0)
    logits = logits / nonzero_t.unsqueeze(dim=1)
    if enable_top_p:
        logits = _apply_top_p(logits, p.unsqueeze(dim=1))

    # Random sampling: [batch_size, num_samples].
    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
    sampled_token_ids = torch.multinomial(probs,
                                          num_samples,
                                          replacement=True)

    # The per-row mask must be [batch_size, 1] so it broadcasts across the
    # samples axis; a bare [batch_size] mask would broadcast along the last
    # (samples) dimension instead and pick rows incorrectly.
    use_sampled = (t != 0).unsqueeze(dim=1)
    next_token_ids = torch.where(use_sampled, sampled_token_ids,
                                 argmax_token_ids)
    if num_samples == 1:
        next_token_ids = next_token_ids.squeeze(dim=-1)
    return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    """Return the padded length for a prefill of ``x`` tokens.

    Lengths up to 16 become 16; longer lengths are rounded up to the next
    power of two (e.g. 17 -> 32, 33 -> 64, 64 -> 64).
    """
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. Every power of two >= 16 satisfies that,
    # so we pad to the next power of two (not merely the next multiple of 16).
    # This keeps the number of distinct padded lengths logarithmic in the
    # maximum prompt length (presumably to limit XLA recompilations -- TODO
    # confirm) and is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
    """Return the padded batch size for ``batch_size`` sequences.

    Batch sizes up to 8 are padded to 8; larger batch sizes are rounded up
    to the next multiple of 16 (e.g. 9 -> 16, 17 -> 32).
    """
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8 (8 * topk is a multiple of 16 assuming topk is even -- TODO
    # confirm for all supported MoE models).
    if batch_size <= 8:
        return 8
    return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the top-p (nucleus) set with ``-inf``.

    Args:
        logits: Logits of shape [batch_size, vocab_size].
        p: Top-p thresholds broadcastable against the cumulative
            probabilities, e.g. shape [batch_size, 1].

    Returns:
        A new tensor equal to ``logits`` except that every logit strictly
        below the cutoff logit of its row is ``-inf``. The input is not
        modified.
    """
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    # Number of tokens whose cumulative probability is still below p; the
    # token at this (sorted) index is the last one kept.
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    # If every cumulative probability is below p (e.g. p >= 1, or p == 1 with
    # float rounding leaving the total slightly under 1), the count equals
    # vocab_size and would index past the end. Clamp so that case keeps all
    # tokens instead.
    cutoff_index = cutoff_index.clamp(max=logits_sorted.shape[-1] - 1)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    return logits.masked_fill(logits < cutoff_logit, -float("inf"))
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894


def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Package a flat list of decoded tokens into a ``SamplerOutput``.

    Tokens are consumed in order across all groups: the k-th sequence ID
    visited (group by group, then sequence by sequence within a group)
    receives ``next_token_ids[k]``. Every token is reported with a
    placeholder ``Logprob(0.0)``.

    Args:
        next_token_ids: One sampled token ID per sequence, flattened over
            all groups.
        seq_groups: Sequence IDs grouped per sequence group.

    Returns:
        A ``SamplerOutput`` with one ``CompletionSequenceGroupOutput`` per
        entry of ``seq_groups``.
    """
    placeholder_logprob = Logprob(0.0)
    group_outputs = []
    offset = 0
    for seq_ids in seq_groups:
        seq_outputs = []
        for k, seq_id in enumerate(seq_ids):
            token_id = next_token_ids[offset + k]
            seq_outputs.append(
                SequenceOutput(seq_id, token_id,
                               {token_id: placeholder_logprob}))
        offset += len(seq_ids)
        group_outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
    return SamplerOutput(group_outputs)