tpu_model_runner.py 47.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import bisect
3
import time
4
from typing import TYPE_CHECKING, Optional, cast
5
6
7
8
9
10
11
12
13
14
from unittest.mock import patch

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

15
import vllm.envs as envs
16
17
18
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
19
from vllm.forward_context import set_forward_context
20
21
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
22
23
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
24
from vllm.multimodal.utils import group_mm_inputs_by_modality
25
from vllm.sampling_params import SamplingType
26
from vllm.sequence import IntermediateTensors
27
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
28
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
29
                                               PallasMetadata)
30
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
31
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
32
                                        KVCacheSpec, SlidingWindowSpec)
33
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
34
35
36
                             ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
37
38
39
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

40
41
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
42

43
if TYPE_CHECKING:
44
    from vllm.v1.core.sched.output import SchedulerOutput
45
46
47
48
49
50

logger = init_logger(__name__)

# Slot index written into the padded tail of the slot mapping (see
# `_prepare_inputs`). Here we utilize the behavior that out-of-bound index is
# ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1

# Smallest output size; also used as the lower bound for `max_num_reqs` in
# `TPUModelRunner.__init__`.
MIN_NUM_SEQS = 8


class TPUModelRunner:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache the configs and pre-allocate the persistent CPU buffers.

        Args:
            vllm_config: The engine configuration; its sub-configs are also
                stored as individual attributes for convenience.
            device: Device that model inputs are moved to (via ``.to``) when
                they are prepared in ``_prepare_inputs``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        self.enforce_eager = model_config.enforce_eager

        # Number of compiled XLA graphs seen so far. It is only advanced by
        # `_update_num_xla_graphs` when recompilation checking is enabled and
        # eager mode is off.
        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        self._hidden_states_dtype = self.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.input_ids_np = self.input_ids_cpu.numpy()

        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int64,
                                            device="cpu")
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()

        # NOTE(review): `_prepare_inputs` only reads the first `max_num_reqs`
        # rows of this buffer, yet it is sized by `max_num_tokens`; confirm
        # whether code outside this view needs the extra rows.
        self.block_table_cpu = torch.zeros(
            (self.max_num_tokens, self.max_num_blocks_per_req),
            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)

        # Candidate sizes that the number of scheduled tokens is padded up to
        # in `_prepare_inputs` (via `_get_padded_token_len`).
        self.num_tokens_paddings = _get_paddings(
            min_token_size=16,
            max_token_size=self.max_num_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)

    def _update_num_xla_graphs(self, case_str):
        """Account for XLA graphs compiled since the previous call.

        Does nothing unless recompilation checking is enabled and eager mode
        is off. Otherwise, if the number of cached compilation graphs grew,
        logs the increase (attributing it to ``case_str``) and adds it to
        ``self.num_xla_graphs``.
        """
        if not self.check_recompilation or self.enforce_eager:
            return

        delta = xr.get_num_cached_compilation_graph() - self.num_xla_graphs
        if delta != 0:
            logger.info("Add new %d compiled XLA graphs due to %s", delta,
                        case_str)
            self.num_xla_graphs += delta

    def _verify_num_xla_graphs(self, case_str):
        """Fail if an XLA graph was compiled after warm-up.

        Does nothing unless recompilation checking is enabled and eager mode
        is off. Otherwise compares the current number of cached compilation
        graphs against ``self.num_xla_graphs``.

        Raises:
            AssertionError: If the number of cached graphs changed. It is
                raised explicitly rather than via ``assert`` so that this
                opt-in check still runs under ``python -O``.
        """
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        if self.num_xla_graphs != curr_cached_graph:
            raise AssertionError(
                "Recompilation after warm up is detected during {}."
                " num_xla_graphs = {} curr_cached_graph = {}".format(
                    case_str, self.num_xla_graphs, curr_cached_graph))

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input device tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request, i.e. if any
            request was removed from or added to the persistent batch.
            If False, we can skip copying SamplingMetadata to the device.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)

        # Record whether the batch composition changed *before* the loop below
        # pops entries from `removed_req_indices`. Finished requests count too:
        # they were removed from the batch (which is then condensed), so a
        # check on `unscheduled_req_ids` alone would miss them.
        batch_changed = bool(removed_req_indices) or bool(req_ids_to_add)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        return batch_changed

    def get_model(self) -> nn.Module:
        """Return the loaded model.

        Raises:
            AssertionError: If no model has been loaded yet. ``self.model`` is
                not assigned in ``__init__``, so it is read with ``getattr``
                to report this clearly instead of failing with an
                ``AttributeError``.
        """
        model = getattr(self, "model", None)
        if model is None:
            raise AssertionError("Model has not been loaded yet.")
        return model

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: For layers with encoder-decoder attention.
            ValueError: For layers with an unrecognized attention type.
        """
        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                if attn_module.sliding_window is not None:
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=attn_module.dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=False,
                    )
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=attn_module.dtype,
                        use_mla=False,
                    )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError(
                    "Encoder-decoder attention is not supported by the TPU "
                    f"model runner (layer: {layer_name}).")
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        """Build padded model inputs and attention metadata for this step.

        Fills the pre-allocated CPU buffers (token ids, positions, slot
        mapping, query start locations, sequence lengths) from the persistent
        batch, pads the token dimension up to a size from
        ``self.num_tokens_paddings`` and the request dimension up to a padded
        request count, and copies the results to ``self.device``.

        Side effects:
            Sets ``self.input_ids`` and ``self.position_ids`` (padded, on
            ``self.device``).

        Returns:
            ``(attn_metadata, logits_indices)``: the ``PallasMetadata`` for the
            attention backend, and for each (padded) request the index of its
            last scheduled token, on ``self.device``.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Get the number of scheduled tokens for each request.
        num_scheduled_tokens_per_req = []
        max_num_scheduled_tokens_all_reqs = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens_per_req.append(num_tokens)
            max_num_scheduled_tokens_all_reqs = max(
                max_num_scheduled_tokens_all_reqs, num_tokens)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        assert max_num_scheduled_tokens_all_reqs > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, what are the corresponding req index.
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, what is its position in corresponding req.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # The persistent batch's block table; reused below when filling the
        # padded device block table, so it is fetched only once.
        batch_block_table = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = batch_block_table.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_np[1:num_reqs + 1])
        # Padded requests get a start location of 1, so their logits index
        # computed below (start - 1) is 0.
        self.query_start_loc_np[num_reqs + 1:] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = _get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens)
        num_padded_tokens = padded_total_num_scheduled_tokens
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[total_num_scheduled_tokens:num_padded_tokens] = 0
        self.input_ids = self.input_ids_cpu[:num_padded_tokens].to(self.device)
        self.position_ids = self.positions_cpu[:num_padded_tokens].to(
            self.device)
        # Padded tokens write to an out-of-range slot (see _PAD_SLOT_ID).
        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = self.slot_mapping_cpu[:num_padded_tokens].to(
            self.device)

        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            batch_block_table[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs],
                                  dtype=torch.int32,
                                  device=self.device),
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)
        return attn_metadata, logits_indices

    def _scatter_placeholders(
        self,
        embeds: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Expand ``embeds`` back to the full placeholder length.

        The result has ``is_embed.shape[0]`` rows; the rows selected by
        ``is_embed`` (presumably a boolean mask) receive ``embeds`` and all
        other rows are NaN. When ``is_embed`` is None, ``embeds`` is returned
        unchanged.
        """
        if is_embed is not None:
            hidden = embeds.shape[-1]
            expanded = embeds.new_full((is_embed.shape[0], hidden),
                                       fill_value=torch.nan)
            expanded[is_embed] = embeds
            return expanded
        return embeds

    def _gather_placeholders(
        self,
        placeholders: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Select the rows of ``placeholders`` picked out by ``is_embed``.

        This undoes ``_scatter_placeholders``. When ``is_embed`` is None,
        ``placeholders`` is returned unchanged.
        """
        return placeholders if is_embed is None else placeholders[is_embed]

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on the inputs scheduled this step.

        Each encoder output is expanded with ``scatter_mm_placeholders``
        (using its placeholder's ``is_embed``) and stored in
        ``self.encoder_cache[req_id][input_id]``. Returns early when no
        encoder inputs were scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs: list[torch.Tensor] = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            # Iterating either form above yields one tensor per item.
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs. They line up with `req_ids_pos` because
        # the grouping above preserves item order.
        for (req_id, input_id, pos_info), output in zip(req_ids_pos,
                                                        encoder_outputs):
            self.encoder_cache.setdefault(req_id, {})[input_id] = (
                scatter_mm_placeholders(output, is_embed=pos_info.is_embed))

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
        """Collect the cached encoder embeddings needed by this step.

        For every request in the persistent batch, takes the part of each
        cached encoder output that overlaps the token range scheduled this
        step, ``[num_computed_tokens, num_computed_tokens + num_scheduled)``,
        and keeps only its embedding rows (per the placeholder's
        ``is_embed``).

        Returns:
            The gathered embeddings, ordered by request and then by
            placeholder.
        """
        mm_embeds: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    # NOTE(review): `break` (not `continue`) assumes
                    # `mm_positions` is sorted by offset — confirm.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds

    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
649
        intermediate_tensors: Optional[IntermediateTensors] = None,
650
651
652
    ) -> ModelRunnerOutput:
        # Update cached state
        self._update_states(scheduler_output)
653
654
655
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOuptut if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
656

657
658
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
659
660
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
661
        else:
662
            mm_embeds = []
663

664
665
        # Prepare inputs
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
666
667
668
669
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
670
            if mm_embeds:
671
                inputs_embeds = self.model.get_input_embeddings(
672
                    self.input_ids, mm_embeds)
673
674
675
676
677
678
679
680
681
682
            else:
                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids
            inputs_embeds = None
683
        num_reqs = self.input_batch.num_reqs
684
685
686
        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
        # are copied to device in chunks of pre-compiled padded shape to
        # avoid recompilations.
687
        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
688
            from_input_batch(self.input_batch, logits_indices)
689
690
691
        # Run the decoder
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
692
693
                input_ids=input_ids,
                positions=self.position_ids,
694
                kv_caches=self.kv_caches,
695
                inputs_embeds=inputs_embeds,
696
            )
697
698
699
        selected_token_ids = self.model.sample_from_hidden(
            hidden_states, tpu_sampling_metadata)
        # Remove padding on cpu and keep dynamic op outside of xla graph.
700
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
701

702
703
        # Update the cache state concurrently. Code above will not block until
        # we use `selected_token_ids`. Add mark_step if post-processing changes
704
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
705
        discard_sampled_tokens_req_indices = []
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

721
722
723
724
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

725
726
727
        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
728
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
729

730
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
731
        for req_id in self.input_batch.req_ids[:num_reqs]:
732
733
            prompt_logprobs_dict[req_id] = None

734
735
736
        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()
737

738
739
740
741
742
743
744
            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            #       the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
745
746
747
748
749
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1
750

751
752
753
754
755
756
757
758
759
760
761
762
763
764
        else:
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

765
        model_runner_output = ModelRunnerOutput(
766
            req_ids=req_ids,
767
            req_id_to_index=self.input_batch.req_id_to_index,
768
            sampled_token_ids=valid_sampled_token_ids,
769
            spec_token_ids=None,
770
            logprobs=None,
771
            prompt_logprobs_dict=prompt_logprobs_dict,
772
        )
773
774
775
776
777

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
        return model_runner_output

    def load_model(self) -> None:
        """Load the model and wrap it for XLA compilation.

        Sets `self.device`, loads the model with `get_model` while the
        embedding layer's TP rank is taken from the XLA runtime (see the
        note below), switches it to eval mode and waits for pending device
        work. It then stores a `torch.compile`d (openxla backend)
        `ModelWrapperV1` around it in `self.model`.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        embedding_rank_fn = (
            "vllm.model_executor.layers.vocab_parallel_embedding."
            "get_tensor_model_parallel_rank")
        runtime_rank = xr.global_ordinal()
        with patch(embedding_rank_fn, return_value=runtime_rank):
            loaded_model = get_model(vllm_config=self.vllm_config)
        loaded_model = loaded_model.eval()

        # Execute the pending XLA operations and wait for the device.
        xm.mark_step()
        xm.wait_device_ops()

        wrapped_model = ModelWrapperV1(loaded_model)
        self.model = torch.compile(wrapped_model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

807
808
    @torch.no_grad()
    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
809
810
811
812
813
814
815
816
817
818
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
                                        device=self.device)
        else:
            input_ids = torch.zeros((num_tokens),
                                    dtype=torch.int32,
                                    device=self.device)
            inputs_embeds = None
819
        actual_num_reqs = min(num_tokens, self.max_num_reqs)
820
821
822
823
824
825
        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros(num_tokens,
                                   dtype=torch.int64,
                                   device=self.device)
826
827
828
829
830
        block_tables = torch.zeros(
            (self.max_num_reqs, self.block_table_cpu.shape[1]),
            dtype=torch.int32,
            device=self.device)
        query_lens = [1] * self.max_num_reqs
831
832
833
834
        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                    dtype=torch.int32),
                                       dim=0,
                                       dtype=torch.int32).to(self.device)
835
        context_lens = torch.ones((self.max_num_reqs, ),
836
837
                                  dtype=torch.int32,
                                  device=self.device)
838
839
840
        num_seqs = torch.tensor([actual_num_reqs],
                                dtype=torch.int32,
                                device=self.device)
841
842
843
844
845
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
846
            num_seqs=num_seqs,
847
        )
848

849
850
851
852
        if self.is_multimodal_model:
            torch._dynamo.mark_dynamic(inputs_embeds, 0)
        else:
            torch._dynamo.mark_dynamic(input_ids, 0)
853
854
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
855
856

        with set_forward_context(attn_metadata, self.vllm_config, 0):
857
858
859
860
861
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             kv_caches=kv_caches,
                             inputs_embeds=inputs_embeds)
        self._hidden_states_dtype = out.dtype
862
863
864
865

    def capture_model(self) -> None:
        """Compile the model.

        Pre-compiles the XLA graphs used by `execute_model` in two phases:

        1. The model forward pass, once per padded token count in
           `self.num_tokens_paddings` (via `_dummy_run`).
        2. The sampling step (`sample_from_hidden`), once per padded token
           count and per padded number of sampled requests, from
           `MIN_NUM_SEQS` up to `self.max_num_reqs`.

        After each phase, `_update_num_xla_graphs` is called with the phase
        name ("model" / "sampling").
        """
        logger.info("Compiling the model with different input shapes.")

        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(self.kv_caches, num_tokens)
            xm.mark_step()
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model")

        logger.info("Compiling sampling with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        device = self.device
        # Compile sampling step for different model+sampler outputs in bucketed
        # n_tokens x max_num_reqs. Graph is really small so this is fine.
        for num_tokens in self.num_tokens_paddings:
            num_reqs_to_sample = MIN_NUM_SEQS
            dummy_hidden = torch.randn((num_tokens, hsize),
                                       device=device,
                                       dtype=self._hidden_states_dtype)
            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
            while True:
                indices = torch.zeros(
                    num_reqs_to_sample,
                    dtype=torch.int32,
                    device=device,
                )
                xm.mark_step()
                sampling_meta = TPUSupportedSamplingMetadata.\
                    from_input_batch(self.input_batch, indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                            num_reqs_to_sample)
                out = self.model.sample_from_hidden(dummy_hidden,
                                                    sampling_meta)
                # Bring the result to the host so the sampling graph is
                # actually executed for this shape.
                out = out.cpu()
                # Requests can't be more than tokens. But do compile for the
                # next bigger value in case num_tokens uses bucketed padding.
                if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
                    break
                # Make sure to compile the `max_num_reqs` upper-limit case
                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
                    num_reqs_to_sample + 1, self.max_num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sampling")

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Allocates one zero-filled tensor on `self.device` per attention layer,
        shaped by `PallasAttentionBackend.get_kv_cache_shape`, and binds the
        resulting `{layer_name: tensor}` mapping into the static forward
        context and `self.kv_caches` via `bind_kv_cache`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: If more than one KV cache group is configured,
                or a layer uses a KV cache spec other than `FullAttentionSpec`.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                page_size = kv_cache_spec.page_size_bytes
                # The allocated byte budget must split into whole blocks.
                assert tensor_config.size % page_size == 0, (
                    f"KV cache size {tensor_config.size} of layer "
                    f"{layer_name} is not a multiple of the page size "
                    f"{page_size}.")
                num_blocks = tensor_config.size // page_size
                if not isinstance(kv_cache_spec, FullAttentionSpec):
                    raise NotImplementedError(
                        f"KV cache spec {type(kv_cache_spec).__name__} of "
                        f"layer {layer_name} is not supported; only "
                        "FullAttentionSpec is.")
                kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                    dtype=kv_cache_spec.dtype,
                                                    device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)


class ModelWrapperV1(nn.Module):
    """Thin wrapper exposing a model's forward pass and an XLA-friendly sampler.

    `forward` runs the wrapped model to produce hidden states, and
    `sample_from_hidden` turns hidden states into token ids. The two entry
    points are meant to be traced separately.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.sampler = TPUSampler()

    def sample(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
        """Apply the TPU sampler to `logits` and return its full output."""
        return self.sampler(logits, sampling_metadata)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Executes the forward pass of the model.

        Args:
            input_ids: The input token IDs of shape [num_tokens].
            positions: The input position IDs of shape [num_tokens].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization. This method does not
                pass them to the wrapped model.
            inputs_embeds: The input embeddings of shape [num_tokens,
                hidden_size]. It is used for multimodal models.

        Returns:
            Whatever the wrapped model returns (the hidden states).
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )

    def sample_from_hidden(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """
        Sample with xla-friendly function. This function is to be traced
        separately from `forward` for lighter compilation overhead.
        """
        # Rows selected for sampling; the index tensor has a fixed
        # pre-compiled size.
        sample_rows = sampling_metadata.indices_do_sample
        logits = self.compute_logits(hidden_states[sample_rows])
        # Both branches are computed so the graph traces them in one pass;
        # `all_greedy` is a scalar, so `torch.where` acts as an if/else.
        greedy_tokens = torch.argmax(logits, dim=-1, keepdim=True)
        random_tokens = self.sample(logits,
                                    sampling_metadata).sampled_token_ids
        return torch.where(sampling_metadata.all_greedy, greedy_tokens,
                           random_tokens)

    def compute_logits(self,
                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        """Project hidden states to logits via the wrapped model.

        `None` is passed as sampling metadata, which disables output pruning
        in the logits processor.
        """
        return self.model.compute_logits(hidden_states, None)

    def get_multimodal_embeddings(self, *args, **kwargs):
        """Delegate to the wrapped model's `get_multimodal_embeddings`."""
        return self.model.get_multimodal_embeddings(*args, **kwargs)

    def get_input_embeddings(self, *args, **kwargs):
        """Delegate to the wrapped model's `get_input_embeddings`."""
        return self.model.get_input_embeddings(*args, **kwargs)

1030

1031
1032
def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple
1033
1034


def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    """Pad a request count to a compiled bucket size, capped at `upper_limit`.

    Counts up to `MIN_NUM_SEQS` map to `MIN_NUM_SEQS`; larger counts are
    rounded up to the next power of two. The result never exceeds
    `upper_limit`.
    """
    if x > MIN_NUM_SEQS:
        padded = 1 << (x - 1).bit_length()
    else:
        padded = MIN_NUM_SEQS
    return padded if padded < upper_limit else upper_limit


def _get_paddings(min_token_size: int, max_token_size: int,
                  padding_gap: int) -> list[int]:
    """Generate a list of padding size, starting from min_token_size,
    ending with a number that can cover max_token_size

    If padding_gap == 0 then:
        increase 2X each time (exponential)
    else:
        first increase the size to twice,
        then increase the padding size by padding_gap.

    In exponential mode the last element is always >= max_token_size, so
    `_get_padded_token_len` can pad any count up to max_token_size.

    Raises:
        ValueError: if min_token_size <= 0 or padding_gap < 0, for which the
            loops below would never terminate.
    """
    if min_token_size <= 0:
        raise ValueError(
            f"min_token_size must be positive, got {min_token_size}")
    if padding_gap < 0:
        raise ValueError(
            f"padding_gap must be non-negative, got {padding_gap}")

    paddings = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential paddings:")
        # Keep doubling until a padding covers max_token_size; stopping at
        # `num <= max_token_size` would drop the covering bucket whenever
        # max_token_size is not min_token_size * 2**k.
        while True:
            logger.info("    %d", num)
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        logger.info("Using incremental paddings:")
        while num <= padding_gap:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
    """Return the first element in paddings list greater or equal to x.
    """
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]