# "deploy/operator/internal/consts/consts.go" did not exist on "81c278038f771b44a3fc2cb9976bd396a2d50777"
# tpu_model_runner.py 39.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import enum
5
import time
6
from dataclasses import dataclass
7
8
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
9
from unittest.mock import patch
10
11
12
13
14

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
15
import torch_xla.runtime as xr
16
17

from vllm.attention import AttentionMetadata, get_attn_backend
18
from vllm.config import VllmConfig
19
from vllm.forward_context import get_forward_context, set_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.sampler import SamplerOutput
22
23
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
25
                           Logprob, SequenceGroupMetadata, SequenceOutput)
26
27
28
29
30
31
32
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
33
34
35

logger = init_logger(__name__)

36
37
38
# Slot index written for padding tokens. It is deliberately far beyond any
# real KV-cache slot: we rely on out-of-bound indices being ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID: int = 10**9

# FIXME(woosuk): Top-p sampling is temporarily disabled because it is too
# slow.
_ENABLE_TOP_P: bool = False

# FIXME(woosuk): Temporary hack to support `n > 1`: prefill steps always
# draw this many samples per sequence, so a large value can significantly
# hurt performance.
_MAX_NUM_SAMPLES: int = 128


class ExecutionMode(enum.Enum):
    """Kind of forward pass the TPU model runner traces and executes."""

    PREFILL = 1
    DECODE = 2
    PREFIX_PREFILL = 3

    def is_prefill(self) -> bool:
        """Return True for plain prefill and for prefix-cached prefill."""
        prefill_modes = (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
        return any(self is mode for mode in prefill_modes)


@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Per-step model input for the TPU model runner.

    Instances are produced by ``TPUModelRunner.prepare_model_input``. They
    can be flattened with :meth:`as_broadcastable_tensor_dict` and rebuilt
    with :meth:`from_broadcasted_tensor_dict`.
    """

    # Prefill: 1-D concatenation of the individually padded prompts.
    # Decode: shape (padded_batch, 1).
    token_ids: torch.Tensor
    # Same layout as ``token_ids``.
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    # Prefill: unpadded number of new tokens per prompt. Decode: all ones.
    input_lens: torch.Tensor
    # Per-sequence temperature (t) and top-p (p), padded to the batch.
    t: torch.Tensor
    p: torch.Tensor
    # Samples drawn per sequence: _MAX_NUM_SAMPLES for prefill, 1 for decode.
    num_samples: int
    # Requested ``n`` for each real (unpadded) sequence.
    n: List[int]
    # Sequence ids of each sequence group, in batch order.
    seq_groups: List[List[int]]
    # Multi-step decoding flags consumed by ``execute_model``.
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    # Not included in the broadcast dict.
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Flatten this input into a dict suitable for broadcasting.

        Values include tensors, ints, bools and (nested) int lists; the
        attention metadata is merged in by
        ``_add_attn_metadata_broadcastable_dict``. ``async_callback`` is
        intentionally omitted.
        """
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "n": self.n,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Rebuild an instance from :meth:`as_broadcastable_tensor_dict`.

        When ``attn_backend`` is given, the flattened attention-metadata
        entries are folded back into an ``attn_metadata`` object first.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
102
103
104

    def __init__(
        self,
        vllm_config: VllmConfig,
        is_driver_worker: bool = False,
    ):
        """Set up the host-side block table and the attention backend.

        Args:
            vllm_config: Engine configuration. ``ModelRunnerBase.__init__``
                is expected to expose its sub-configs (``cache_config``,
                ``model_config``, ``scheduler_config``, ...) on ``self``.
            is_driver_worker: Whether this runner belongs to the driver.
        """
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        # NOTE(review): floor division drops a trailing partial block when
        # max_model_len is not a multiple of block_size; confirm such
        # configurations are rejected or otherwise handled elsewhere.
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        # Host-side (numpy) block table with one row per schedulable
        # sequence; rows are filled in place while preparing inputs.
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            False,
        )
        # Sampled token ids of earlier steps, drained by execute_model when
        # running multi-step decoding.
        self.cached_step_outputs: List[torch.Tensor] = []

        # Warn when the int32 block table (4 bytes per entry) reaches the
        # 512 KiB SMEM budget; the suggested max_model_len shrinks the
        # current one by the overflow ratio.
        smem_size = 512 * 1024
        block_table_size = 4 * self.block_tables.size
        if block_table_size >= smem_size:
            logger.warning(
                "The max_model_len (%d) is too large. This may degrade the "
                "performance due to the insufficient smem size. Consider "
                "setting --max-model-len to a smaller value, like %d.",
                self.model_config.max_model_len,
                self.model_config.max_model_len /
                (block_table_size / smem_size))

    def load_model(self) -> None:
        """Load the model and wrap it with ``torch.compile`` for XLA.

        Sets ``self.device`` and ``self.model``. The model is compiled with
        the ``openxla`` backend, ``fullgraph=True`` and ``dynamic=False``.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        # Block until all pending XLA device operations have completed.
        xm.wait_device_ops()
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def get_model(self) -> nn.Module:
        """Return the model held inside the compiled wrapper module."""
        compiled_wrapper = self.model
        return compiled_wrapper.model

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        """Run the compiled model once on dummy inputs of a given shape.

        Used by ``warmup_model`` so every input shape is compiled before
        serving.

        Args:
            batch_size: Number of sequences in the dummy batch.
            seq_len: Tokens per sequence. Rounded up to a multiple of 16 for
                the prefill modes; must be 1 for decode.
            kv_caches: KV-cache tensors passed straight to the model.
            exec_mode: Execution path to trace. It is coerced through
                ``ExecutionMode(...)``, so a member value is also accepted.
        """
        exec_mode = ExecutionMode(exec_mode)
        is_prefill = exec_mode.is_prefill()
        if is_prefill:
            seq_len = (seq_len + 15) // 16 * 16
        else:
            assert seq_len == 1

        # These inputs have the same shape and dtype in every mode.
        token_ids = torch.zeros((batch_size, seq_len),
                                dtype=torch.int32,
                                device=self.device)
        position_ids = torch.zeros((batch_size, seq_len),
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros((batch_size, seq_len),
                                   dtype=torch.int64,
                                   device=self.device)
        input_lens = torch.ones((batch_size, ),
                                dtype=torch.int32,
                                device=self.device)

        if exec_mode == ExecutionMode.PREFILL:
            # No cached context: block tables / context lens stay unset.
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=False,
                block_tables=None,
                context_lens=None,
                effective_query_lens=None,
            )
        elif exec_mode == ExecutionMode.PREFIX_PREFILL:
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            block_tables = torch.tensor(self.block_tables[:batch_size],
                                        dtype=torch.int32,
                                        device=self.device)
            effective_query_lens = torch.ones_like(context_lens)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=False,
                block_tables=block_tables,
                context_lens=context_lens,
                effective_query_lens=effective_query_lens,
            )
        else:
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=False,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if is_prefill else 1

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shapes. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs in the disk (VLLM_XLA_CACHE_PATH).
        if is_prefill:
            # Prefill: the sequence dimension (dim 1) varies across shapes.
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode: the batch dimension (dim 0) varies across shapes.
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        # Dummy run.
        with set_forward_context(attn_metadata, self.vllm_config, 0):
            self.model(token_ids, position_ids, input_lens, t, p, num_samples,
                       kv_caches)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        """Compile the model for every input shape expected at runtime.

        Compiles prefill shapes for batch size 1 and power-of-two sequence
        lengths starting at 16; repeats that for prefix prefill when prefix
        caching is enabled. Then compiles decode shapes for the batch sizes
        8, 16, 32, 48, ... until ``max_num_seqs`` is covered.
        """
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        self._compile_prefill_shapes(kv_caches, ExecutionMode.PREFILL)
        end = time.time()
        logger.info("Compilation for prefill done in %.2f s.", end - start)

        # Prefix prefill
        if self.cache_config.enable_prefix_caching:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill...")
            start = time.time()
            self._compile_prefill_shapes(kv_caches,
                                         ExecutionMode.PREFIX_PREFILL)
            end = time.time()
            logger.info("Compilation for prefix prefill done in %.2f s.",
                        end - start)

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info("Compilation for decode done in %.2f s.", end - start)

    def _compile_prefill_shapes(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        """Compile one prefill-style mode for power-of-two sequence lengths.

        Sequence lengths start at 16 and double while they fit in
        ``max_model_len``, stopping early once ``max_num_batched_tokens`` is
        reached.
        """
        # Prompts are executed one at a time, so only batch size 1 is needed.
        batch_size = 1
        seq_len = 16
        while seq_len <= self.model_config.max_model_len:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=exec_mode)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
            num_tokens = batch_size * seq_len
            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                break
            seq_len = seq_len * 2

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build flattened prefill inputs for a batch of prompt groups.

        Each sequence group must hold exactly one sequence. Tokens already
        covered by computed (prefix-cached) blocks are skipped. The remaining
        tokens of every prompt are padded individually with
        ``_get_padded_prefill_len``, so that prompts can be run one by one.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, prompt_lens)``,
            all on CPU. Tokens, positions and the slot mapping are 1-D
            concatenations of the padded prompts. ``prompt_lens`` holds the
            unpadded number of new tokens per prompt.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        context_lens: List[int] = []
        slot_mapping: List[int] = []

        for batch_idx, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)

            # Skip tokens that live in computed (prefix-cached) blocks. For
            # such prompts the full sequence length is recorded as the
            # context length; otherwise the context length is 0.
            num_computed_blocks = len(seq_group_metadata.computed_block_nums)
            num_computed_tokens = num_computed_blocks * self.block_size
            if num_computed_tokens > 0:
                prompt_tokens = prompt_tokens[num_computed_tokens:]
                context_lens.append(seq_len)
            else:
                context_lens.append(0)

            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(range(num_computed_tokens, seq_len))

            # Map every new token to its KV-cache slot.
            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(num_computed_tokens, seq_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
            if num_computed_tokens > 0:
                # Block tables are only consulted for prompts with cached
                # context (context_len > 0).
                self.block_tables[batch_idx, :len(block_table)] = block_table

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:num_prefills],
                                    dtype=torch.int32,
                                    device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
            effective_query_lens=prompt_lens,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build single-token decode inputs for a batch of sequence groups.

        Every sequence of every group contributes one row. The batch is
        padded with ``_get_padded_batch_size``. Padding rows use token 0,
        position 0, context length 0, slot ``_PAD_SLOT_ID`` and (where the
        host block table has no row for them) an all-zero block table.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, input_lens)``,
            all on CPU. Tokens, positions and the slot mapping have shape
            ``(padded_batch, 1)``. ``input_lens`` is all ones.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                # KV-cache slot for the token at `position`.
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        # The padded batch can exceed max_num_seqs, which is the row count of
        # self.block_tables. Plain slicing would then silently return fewer
        # rows than every other batch tensor. Copy into a zero-filled table
        # of the padded size instead; padding rows have context length 0.
        num_rows = min(batch_size, self.block_tables.shape[0])
        padded_block_tables = np.zeros(
            (batch_size, self.block_tables.shape[1]),
            dtype=self.block_tables.dtype)
        padded_block_tables[:num_rows] = self.block_tables[:num_rows]
        block_tables = torch.tensor(padded_block_tables,
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Collect per-sequence sampling parameters.

        Args:
            seq_group_metadata_list: Scheduled sequence groups; each group's
                parameters are repeated once per sequence in the group.
            padded_batch_size: Length to which ``t`` and ``p`` are padded.

        Returns:
            ``(t, p, n)``. ``t`` holds temperatures and ``p`` holds top-p
            values; both are float32 CPU tensors padded with 1.0. ``n`` holds
            the requested sample counts and is not padded.

        Raises:
            NotImplementedError: If top-p is requested while disabled. Also
                raised for top-k, for ``n`` above ``_MAX_NUM_SAMPLES``, and
                for logprobs or prompt_logprobs.
        """
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        n = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k > 0:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.n > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"n > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            n.append(sampling_params.n)
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            n += [n[-1]] * (num_seqs - 1)

        # Pad t and p (but not n) up to the padded batch size with 1.0.
        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, n

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        """Build the ``ModelInputForTPU`` for one scheduler step.

        Args:
            seq_group_metadata_list: Non-empty list of scheduled groups. All
                groups must be prompts or all must be decodes; the first
                group decides which.
            virtual_engine: Must be 0.
            finished_requests_ids: Ignored.

        Returns:
            Inputs on CPU, ready to be moved to the device by execute_model.
        """
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        # input_lens has exactly one entry per (padded) batch row in both
        # modes. For prefill, input_tokens is a flat token array, so its
        # length is the total padded token count, not the batch size.
        padded_batch_size = input_lens.shape[0]
        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                       padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, n, seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        """Rebuild a ``ModelInputForTPU`` from a broadcast tensor dict.

        The runner's attention backend is used to restore the attention
        metadata.
        """
        return ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)

    @torch.no_grad()
596
597
598
    def execute_model(
        self,
        model_input: ModelInputForTPU,
599
        kv_caches: Optional[List[Any]],
600
601
602
603
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
627
628
                        is_last_step=False,
                        is_first_step_output=i == 0)
629
630
631
632
633
634
635
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
636
        if is_prompt:
637
            assert num_steps == 1
638
639
640
641
642
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
643
644
645
646
            orig_block_tables = model_input.attn_metadata.block_tables
            orig_context_lens = model_input.attn_metadata.context_lens
            orig_effective_query_lens = \
                model_input.attn_metadata.effective_query_lens
647
648
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
649
            next_token_ids = []
650
651
652
653
654
655
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

656
657
658
659
660
661
662
663
664
                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
665
666
667
668
669
670
671
672
673
674
675
                if orig_context_lens[i].item() > 0:
                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
                        self.device)
                    attn_metadata.block_tables = orig_block_tables[
                        i].unsqueeze(0).to(self.device)
                    attn_metadata.effective_query_lens = \
                        orig_effective_query_lens[i:i + 1].to(self.device)
                else:
                    attn_metadata.context_lens = None
                    attn_metadata.block_tables = None
                    attn_metadata.effective_query_lens = None
676
677
678
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
679
680
681
682
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
683
684
                                                  input_lens, t, p,
                                                  model_input.num_samples,
685
                                                  kv_caches)
686
                next_token_ids.append(output_token_ids[0])
687
                start_idx = end_idx
688

689
690
            if model_input.async_callback is not None:
                model_input.async_callback()
691
            # Retrieve the outputs to CPU.
692
693
694
695
696
697
698
699
700
701
702
703
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
704
705
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
706
                seq_outputs = []
707
                for j in range(model_input.n[i]):
708
                    next_token_id = next_token_ids[i][j]
709
710
711
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
730
731
732
733
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
734
735
                                                  input_lens, t, p,
                                                  model_input.num_samples,
736
                                                  kv_caches)
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]
769
770


class ModelWrapper(nn.Module):
    """Wraps a model so that one forward call runs the model and samples.

    The TPU backend does not use the regular vLLM sampler. Instead this
    wrapper computes logits for the last valid position of every sequence and
    performs greedy / temperature (and, when enabled, top-p) sampling itself.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size]. A zero
                temperature selects greedy (argmax) decoding for that row.
            p: The top-p probability of shape [batch_size]. Only used when
                ``_ENABLE_TOP_P`` is set.
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. During the memory profiling
                at initialization they hold empty tensors, in which case the
                slot-mapping rewrite below is skipped.

        Returns:
            The next token IDs: shape [batch_size] when ``num_samples == 1``,
            otherwise [batch_size, num_samples].

        Note:
            When the KV caches are non-empty, this rewrites
            ``attn_metadata.slot_mapping`` of the current forward context in
            place so it indexes the flattened KV cache.
        """
        batch_size, seq_len = token_ids.shape
        # Once hidden states are flattened over (batch, seq), sequence i
        # starts at row i * seq_len; sample from its last real token.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1
        attn_metadata = get_forward_context().attn_metadata

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(token_ids, position_ids)
        # Presumably [batch_size, seq_len, hidden] -> rows indexed by
        # logits_indices above.
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        is_random = t != 0
        nonzero_t = torch.where(is_random, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)

        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        else:
            # `is_random` has shape [batch_size]. Without this unsqueeze it
            # would broadcast against the trailing num_samples axis instead
            # of the batch axis, selecting greedy/random per column rather
            # than per sequence.
            is_random = is_random.unsqueeze(dim=-1)
        next_token_ids = torch.where(is_random, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
    # multiple of 16. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
873
874
875
876
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
877
878
879
880
881
882
883
884
885
886
887
888
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909


def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Build the minimal SamplerOutput for a decode step.

    ``next_token_ids`` is consumed in order, one entry per sequence ID across
    all groups (group order, then sequence order within each group). Each
    sampled token is reported with a single placeholder logprob of 0.0.
    Extra trailing entries in ``next_token_ids`` are ignored.
    """
    placeholder_logprob = Logprob(0.0)
    group_outputs = []
    cursor = 0
    for group_seq_ids in seq_groups:
        per_seq = []
        for seq_id in group_seq_ids:
            token_id = next_token_ids[cursor]
            per_seq.append(
                SequenceOutput(seq_id, token_id,
                               {token_id: placeholder_logprob}))
            cursor += 1
        group_outputs.append(CompletionSequenceGroupOutput(per_seq, None))
    return SamplerOutput(group_outputs)