# tpu_model_runner.py
1
import enum
2
import time
3
from dataclasses import dataclass
4
5
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
6
from unittest.mock import patch
7
8
9
10
11

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
12
import torch_xla.runtime as xr
13
14

from vllm.attention import AttentionMetadata, get_attn_backend
15
from vllm.config import VllmConfig
16
from vllm.forward_context import set_forward_context
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.sampler import SamplerOutput
19
20
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
21
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
22
                           Logprob, SequenceGroupMetadata, SequenceOutput)
23
24
25
26
27
28
29
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
30
31
32

logger = init_logger(__name__)

33
34
35
# Slot id written into `slot_mapping` for padding tokens (see the padding code
# in `_prepare_prompt` and `_prepare_decode`).
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
36
37
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
# While this is False, `_prepare_sample` raises NotImplementedError for any
# request whose `top_p` is not 1.
_ENABLE_TOP_P = False
38
39
40
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
# Passed to the model as `num_samples` for prompt runs, and used by
# `_prepare_sample` as the upper bound on `sampling_params.n`.
_MAX_NUM_SAMPLES = 128
41
42


43
44
45
46
47
48
49
50
51
class ExecutionMode(enum.Enum):
    """Kind of forward pass the TPU model runner is executing."""

    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        """Return True for either prefill variant and False for decode."""
        if self is ExecutionMode.PREFILL:
            return True
        return self is ExecutionMode.PREFIX_PREFILL


52
53
54
55
56
57
58
59
60
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Immutable bundle of everything one TPU model-runner step consumes.

    Instances are built by `TPUModelRunner.prepare_model_input` on the driver
    and can be rebuilt on other workers from the dictionary produced by
    `as_broadcastable_tensor_dict`.
    """
    # Token ids and positions fed to the model.
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    # Attention metadata created by the attention backend's `make_metadata`.
    attn_metadata: AttentionMetadata
    # Per-sequence input lengths (prompt lengths for prefill, ones for decode).
    input_lens: torch.Tensor
    # Per-sequence sampling temperature and top-p values.
    t: torch.Tensor
    p: torch.Tensor
    # Number of samples the model is asked to draw per sequence.
    num_samples: int
    # `sampling_params.n` for each sequence.
    n: List[int]
    # Sequence ids, grouped per sequence group.
    seq_groups: List[List[int]]
    # Multi-step bookkeeping flags read by `execute_model`.
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    # Optional callback used for asynchronous output processing. It is not
    # part of the broadcast dictionary.
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Flatten this input into a dictionary that can be broadcast.

        Every field except `async_callback` is included; the attention
        metadata is merged in by `_add_attn_metadata_broadcastable_dict`.

        NOTE(review): the values also include the `n` / `seq_groups` lists
        and two bools, so the declared value type is narrower than what is
        actually stored.
        """
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "n": self.n,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Rebuild a model input from a broadcast dictionary.

        Args:
            tensor_dict: Dictionary as produced by
                `as_broadcastable_tensor_dict`.
            attn_backend: When given, the attention metadata entries in
                `tensor_dict` are first converted by
                `_init_attn_metadata_from_tensor_dict`.

        Returns:
            A new `ModelInputForTPU` whose fields are the (possibly
            converted) dictionary entries.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
99
100
101

    def __init__(
        self,
        vllm_config: VllmConfig,
        is_driver_worker: bool = False,
    ):
        """Set up block-table storage and the attention backend.

        Args:
            vllm_config: Engine configuration, forwarded to
                `ModelRunnerBase.__init__`. The `cache_config`,
                `model_config` and `scheduler_config` attributes read below
                are presumably populated by that base initializer.
            is_driver_worker: Stored on the instance as-is.
        """
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        # One block-table row per schedulable sequence; rows are overwritten
        # in place while preparing the inputs of each step.
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            False,
        )
        # Device-side outputs buffered between multi-step iterations.
        self.cached_step_outputs: List[torch.Tensor] = []

        # Warn when the block table (in bytes) reaches the 512 KiB smem
        # budget, and suggest a proportionally smaller max_model_len.
        smem_size = 512 * 1024
        block_table_size = self.block_tables.nbytes
        if block_table_size >= smem_size:
            logger.warning(
                "The max_model_len (%d) is too large. This may degrade the "
                "performance due to the insufficient smem size. Consider "
                "setting --max-model-len to a smaller value, like %d.",
                self.model_config.max_model_len,
                int(self.model_config.max_model_len /
                    (block_table_size / smem_size)))
134

135
136
137
    def load_model(self) -> None:
        """Load the model, wrap it, and compile it with the openxla backend.

        Sets `self.device` from the device config and stores the compiled
        `ModelWrapper` in `self.model`.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        # Wait for pending device operations before wrapping and compiling.
        xm.wait_device_ops()
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)
160

161
162
163
    def get_model(self) -> nn.Module:
        """Return the `model` attribute of the object stored in `self.model`."""
        compiled_wrapper = self.model
        return compiled_wrapper.model

164
165
166
167
168
    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        """Run the model once on dummy inputs of the requested shape.

        Used by `warmup_model` so that the graph for each input shape is
        compiled ahead of time.

        Args:
            batch_size: Number of rows in the dummy batch.
            seq_len: Tokens per row. Rounded up to a multiple of 16 for the
                prefill modes; must be exactly 1 for decode.
            kv_caches: KV caches passed through to the model.
            exec_mode: Which kind of pass (prefill, prefix prefill or decode)
                to build dummy attention metadata for.
        """
        # Accepts either an ExecutionMode member or its raw value.
        exec_mode = ExecutionMode(exec_mode)
        if exec_mode.is_prefill():
            # Round the sequence length up to the next multiple of 16.
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            if exec_mode == ExecutionMode.PREFILL:
                # Plain prefill: no cached context, so no block tables.
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=batch_size,
                    num_prefill_tokens=batch_size * seq_len,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=False,
                    block_tables=None,
                    context_lens=None,
                    effective_query_lens=None,
                )
            else:
                # Prefix prefill: additionally supply context lengths, block
                # tables and effective query lengths.
                context_lens = torch.ones((batch_size, ),
                                          dtype=torch.int32,
                                          device=self.device)
                block_tables = torch.tensor(self.block_tables[:batch_size],
                                            dtype=torch.int32,
                                            device=self.device)
                effective_query_lens = torch.ones_like(context_lens)
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=batch_size,
                    num_prefill_tokens=batch_size * seq_len,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=False,
                    block_tables=block_tables,
                    context_lens=context_lens,
                    effective_query_lens=effective_query_lens,
                )
        else:
            # Decode processes exactly one token per sequence.
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=False,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        # Dummy temperature / top-p values, one per row.
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shapes. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs in the disk (VLLM_XLA_CACHE_PATH).
        if exec_mode.is_prefill():
            # Prefill: the sequence-length dimension (dim 1) varies.
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode: the batch dimension (dim 0) varies.
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        # Dummy run.
        with set_forward_context(attn_metadata, self.vllm_config, 0):
            self.model(token_ids, position_ids, attn_metadata, input_lens, t,
                       p, num_samples, kv_caches)
278
279
280
281
282
283
284
285
286
287

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        """Pre-compile the model for the input shapes used at run time.

        Runs dummy prefill passes (and prefix-prefill passes when prefix
        caching is enabled) over doubling sequence lengths, then dummy decode
        passes over a growing sequence of batch sizes, logging the time each
        phase takes.

        Args:
            kv_caches: KV caches passed through to every dummy run.
        """
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        self._warmup_prefill_shapes(kv_caches, ExecutionMode.PREFILL)
        end = time.time()
        logger.info("Compilation for prefill done in %.2f s.", end - start)

        # Prefix prefill
        if self.cache_config.enable_prefix_caching:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill...")
            start = time.time()
            self._warmup_prefill_shapes(kv_caches,
                                        ExecutionMode.PREFIX_PREFILL)
            end = time.time()
            logger.info("Compilation for prefix prefill done in %.2f s.",
                        end - start)

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            # Double up to 16, then grow in steps of 16.
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info("Compilation for decode done in %.2f s.", end - start)

    def _warmup_prefill_shapes(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        # Dummy-run `exec_mode` at batch size 1 for seq_len 16, 32, 64, ...
        # until max_model_len or max_num_batched_tokens is reached.
        batch_size = 1
        seq_len = 16
        while seq_len <= self.model_config.max_model_len:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=exec_mode)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
            num_tokens = batch_size * seq_len
            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                break
            seq_len = seq_len * 2

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build the CPU-side model inputs for a batch of prompts.

        Each sequence group must be a prompt containing exactly one sequence.
        The tokens of all prompts are concatenated into flat 1-D tensors, each
        prompt individually padded to the length returned by
        `_get_padded_prefill_len`.

        Args:
            seq_group_metadata_list: Non-empty list of prompt sequence groups.

        Returns:
            `(input_tokens, input_positions, attn_metadata, prompt_lens)`,
            where `prompt_lens` holds the unpadded number of new tokens of
            each prompt.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        context_lens: List[int] = []
        slot_mapping: List[int] = []

        for batch_idx, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)

            # Tokens covered by already-computed blocks are not fed to the
            # model again; only the remaining suffix is.
            num_computed_blocks = len(seq_group_metadata.computed_block_nums)
            num_computed_tokens = num_computed_blocks * self.block_size
            if num_computed_tokens > 0:
                prompt_tokens = prompt_tokens[num_computed_tokens:]
                context_lens.append(seq_len)
            else:
                context_lens.append(0)

            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(range(num_computed_tokens, seq_len))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            # Map every new token position to its slot in the KV cache.
            for i in range(num_computed_tokens, seq_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
            # The block table row is only recorded when there is a cached
            # prefix.
            if num_computed_tokens > 0:
                self.block_tables[batch_idx, :len(block_table)] = block_table

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:num_prefills],
                                    dtype=torch.int32,
                                    device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
            effective_query_lens=prompt_lens,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens
437
438
439
440

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build the CPU-side model inputs for one decode step.

        Every sequence of every group contributes one row holding its last
        token. The batch is then padded to the size returned by
        `_get_padded_batch_size`.

        Args:
            seq_group_metadata_list: Non-empty list of decode sequence groups.

        Returns:
            `(input_tokens, input_positions, attn_metadata, input_lens)`,
            where `input_lens` is all ones.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        # Running row index across all sequences of all groups.
        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                # Slot in the KV cache that the new token is written to.
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        # Pad the batch; padded rows use token 0, position 0, the pad slot
        # and a context length of 0.
        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens
508
509
510
511
512

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Collect per-sequence sampling parameters for the batch.

        Args:
            seq_group_metadata_list: Non-empty list of sequence groups.
            padded_batch_size: Length that the temperature and top-p tensors
                are padded to (with 1.0).

        Returns:
            `(t, p, n)`: float32 CPU tensors of temperatures and top-p values
            of length `padded_batch_size`, and the unpadded list of
            `sampling_params.n`, one entry per sequence.

        Raises:
            NotImplementedError: If a request uses top-p (while it is
                disabled), top-k, `n` above `_MAX_NUM_SAMPLES`, logprobs or
                prompt logprobs.
        """
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        n = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.n > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            n.append(sampling_params.n)
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            n += [n[-1]] * (num_seqs - 1)

        # Only t and p are padded; n keeps one entry per real sequence.
        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, n
556

557
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        """Assemble the `ModelInputForTPU` for one scheduler step.

        Args:
            seq_group_metadata_list: Non-empty list of sequence groups, all
                prompts or all decodes.
            virtual_engine: Must be 0.
            finished_requests_ids: Ignored.

        Returns:
            The model input holding CPU tensors, sampling parameters and the
            sequence ids of every group.
        """
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        # For decode this is the padded number of sequences. For prompts the
        # token tensor is flat, so it is the total padded token count.
        padded_batch_size = input_tokens.shape[0]
        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                       padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, n, seq_groups)
585
586
587
588
589
590
591

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        """Rebuild a `ModelInputForTPU` from a broadcast tensor dictionary.

        Delegates to `ModelInputForTPU.from_broadcasted_tensor_dict`, passing
        this runner's attention backend.
        """
        return ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

592
    @torch.no_grad()
593
594
595
    def execute_model(
        self,
        model_input: ModelInputForTPU,
596
        kv_caches: Optional[List[Any]],
597
598
599
600
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
624
625
                        is_last_step=False,
                        is_first_step_output=i == 0)
626
627
628
629
630
631
632
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
633
        if is_prompt:
634
            assert num_steps == 1
635
636
637
638
639
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
640
641
642
643
            orig_block_tables = model_input.attn_metadata.block_tables
            orig_context_lens = model_input.attn_metadata.context_lens
            orig_effective_query_lens = \
                model_input.attn_metadata.effective_query_lens
644
645
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
646
            next_token_ids = []
647
648
649
650
651
652
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

653
654
655
656
657
658
659
660
661
                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
662
663
664
665
666
667
668
669
670
671
672
                if orig_context_lens[i].item() > 0:
                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
                        self.device)
                    attn_metadata.block_tables = orig_block_tables[
                        i].unsqueeze(0).to(self.device)
                    attn_metadata.effective_query_lens = \
                        orig_effective_query_lens[i:i + 1].to(self.device)
                else:
                    attn_metadata.context_lens = None
                    attn_metadata.block_tables = None
                    attn_metadata.effective_query_lens = None
673
674
675
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
676
677
678
679
680
681
682
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  attn_metadata, input_lens, t,
                                                  p, model_input.num_samples,
                                                  kv_caches)
683
                next_token_ids.append(output_token_ids[0])
684
                start_idx = end_idx
685

686
687
            if model_input.async_callback is not None:
                model_input.async_callback()
688
            # Retrieve the outputs to CPU.
689
690
691
692
693
694
695
696
697
698
699
700
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
701
702
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
703
                seq_outputs = []
704
                for j in range(model_input.n[i]):
705
                    next_token_id = next_token_ids[i][j]
706
707
708
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
727
728
729
730
731
732
733
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  attn_metadata, input_lens, t,
                                                  p, model_input.num_samples,
                                                  kv_caches)
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]
766
767


768
class ModelWrapper(nn.Module):
    """Wraps a model so one call runs it, computes logits and samples.

    The wrapped module is expected to expose `compute_logits` in addition
    to being callable with (token_ids, position_ids, kv_caches,
    attn_metadata).
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Note that `attn_metadata.slot_mapping` is overwritten in place with
        a flattened, per-KV-head version whenever the KV cache is non-empty.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. The first cache tensor is
                expected to be empty (zero elements) during the memory
                profiling at initialization, in which case the slot mapping
                is left untouched.

        Returns:
            The next token IDs. Assuming `compute_logits` yields one row per
            sequence, the shape is [batch_size] when `num_samples == 1` and
            [batch_size, num_samples] otherwise.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from: the last real token of
        # each row in the flattened [batch_size * seq_len] layout.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            # Offset of each head's region in the flattened cache.
            head_offsets = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_offsets *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_offsets.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)

        # Rows with zero temperature take the greedy token, all others the
        # sampled one.
        use_sampled = t != 0
        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        else:
            # The candidates are [batch_size, num_samples] here; give the
            # per-row mask a trailing axis so it broadcasts along the sample
            # dimension instead of being aligned with it.
            use_sampled = use_sampled.unsqueeze(dim=-1)
        next_token_ids = torch.where(use_sampled, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We round the prompt length up to the
    # next power of two, with a floor of 16, which always satisfies that
    # requirement. This is also good for performance.
    min_padded_len = 16
    if x > min_padded_len:
        return 2**(x - 1).bit_length()
    return min_padded_len


def _get_padded_batch_size(batch_size: int) -> int:
    """Return the batch size after padding.

    Batch sizes up to 8 are padded to 8; larger ones are rounded up to the
    next multiple of 16.
    """
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
        return 8
    return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask out the logits that fall outside the top-p (nucleus) set.

    Args:
        logits: The logits to filter along the last dimension. NOTE: this
            tensor is modified in place.
        p: The cumulative-probability threshold. It must broadcast against
            `logits` with a last dimension of size 1 (the caller in this
            file passes `p.unsqueeze(dim=1)`).

    Returns:
        `logits`, with every entry below the cutoff logit set to -inf.
    """
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    # The number of leading entries whose cumulative probability is still
    # below p is the sorted index of the entry that crosses the threshold.
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    # If every cumulative probability is below p (p >= 1, or the total
    # rounds to just under p), the count equals the vocabulary size, which
    # is not a valid index. Clamp to the last entry so nothing is masked.
    cutoff_index = cutoff_index.clamp(max=logits.shape[-1] - 1)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912


def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Build the sampler output for one decode step.

    `next_token_ids` is consumed in order: the k-th sequence ID encountered
    while walking `seq_groups` receives the k-th token. Every token is
    reported with the same placeholder log-probability of 0.0.
    """
    placeholder_logprob = Logprob(0.0)
    group_outputs = []
    cursor = 0
    for seq_ids in seq_groups:
        samples = []
        for flat_idx, seq_id in enumerate(seq_ids, start=cursor):
            token_id = next_token_ids[flat_idx]
            samples.append(
                SequenceOutput(seq_id, token_id,
                               {token_id: placeholder_logprob}))
        # Advance past the tokens consumed by this group.
        cursor += len(samples)
        group_outputs.append(CompletionSequenceGroupOutput(samples, None))
    return SamplerOutput(group_outputs)