tpu_model_runner.py 88.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import bisect
4
import gc
5
import time
6
from typing import TYPE_CHECKING, Any, Optional, cast
7
8
9
10
11
12
13
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
14
import torch_xla.distributed.spmd as xs
15
16
import torch_xla.runtime as xr

17
import vllm.envs as envs
18
19
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
20
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
21
22
from vllm.config import (ParallelConfig, VllmConfig,
                         get_layers_from_vllm_config, update_config)
23
from vllm.forward_context import set_forward_context
24
from vllm.logger import init_logger
25
from vllm.lora.layers import BaseLayerWithLoRA
26
from vllm.model_executor.model_loader import get_model_loader
27
from vllm.model_executor.model_loader.tpu import TPUModelLoader
28
from vllm.model_executor.models.interfaces_base import is_pooling_model
29
from vllm.multimodal import MULTIMODAL_REGISTRY
30
31
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                    PlaceholderRange)
32
from vllm.multimodal.utils import group_mm_inputs_by_modality
33
from vllm.pooling_params import PoolingTask
34
from vllm.sequence import IntermediateTensors
35
36
37
38
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
                        prev_power_of_2)
from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
                                               PallasAttentionBackend,
39
40
                                               PallasMetadata,
                                               get_page_size_bytes)
41
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
42
43
44
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec,
                                        SlidingWindowSpec)
45
46
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
                             LogprobsTensors, ModelRunnerOutput)
47
48
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
49
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
50
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
51

52
from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
53
                    sanity_check_mm_encoder_outputs)
54

55
if TYPE_CHECKING:
56
    from vllm.v1.core.sched.output import SchedulerOutput
57
58
59

# Module-level logger; used below to report newly compiled XLA graphs.
logger = init_logger(__name__)

60
# Sentinel token id. NOTE(review): presumably marks a slot that holds no valid
# token; its uses are outside this part of the file -- confirm there.
INVALID_TOKEN_ID = -1
61
62
# Smallest output size.
# TPUModelRunner.__init__ uses this as the lower bound for `max_num_reqs` and
# as `min_req_size` when building the request padding buckets.
MIN_NUM_SEQS = 8
63
64


65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#########################################################
# Ways to avoid recompilation
#########################################################
#
# The model executor has two primary components:
# 1. preparing the model and sampler inputs
# 2. executing the model and sampler.
# The core idea is to avoid any TPU computation during input preparation. For
# better compilation tracking and increased flexibility, the model execution and
# sampler are divided into several distinct components.
#
# Below are the detailed steps:
#
# Step 1
# It is recommended to avoid TPU operations when preparing the model and sampler
# inputs. CPU tensors can be prepared and transferred to the XLA device using
# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
# compilation.
#
# Step 2
# The TPU execution should be decomposed into subgraphs (4 at the moment):
# 1. the main model
# 2. selecting hidden states for each request
# 3. sampler
# 4. encoder.
# Each subgraph should be decorated in a torch.compile. This is used to make
# sure that we have the same subgraph topology in both dummy_run and
# execute_model. The results from these subgraphs should either be passed to
# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
# subsequent processing on the CPU.
#
# Step 3
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate
# pre-compilation.
100
class TPUModelRunner(LoRAModelRunnerMixin):
101
102
103
104
105

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        original_parallel_config: Optional[ParallelConfig] = None,
    ):
        """Store the configuration and pre-allocate host-side state.

        Args:
            vllm_config: Aggregated engine configuration. Each sub-config is
                kept as an attribute of the runner.
            device: Device handed to the persistent input batch.
            original_parallel_config: Kept as an attribute; not otherwise
                used by this constructor.
        """
        # Configuration handles.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        # SPMD: build a (num_devices, 1) mesh over all runtime devices.
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        if self.use_spmd:
            num_devices = xr.global_runtime_device_count()
            self.mesh = xs.Mesh(np.array(range(num_devices)),
                                (num_devices, 1), ('x', 'y'))

        self.enforce_eager = self.model_config.enforce_eager

        # Baseline for the recompilation check.
        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype

        # Resolve the KV cache dtype: an explicit cache dtype wins, otherwise
        # follow the model dtype (translated if it is given as a string).
        if self.cache_config.cache_dtype != "auto":
            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]
        elif isinstance(self.dtype, str):
            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[self.dtype]
        else:
            self.kv_cache_dtype = self.dtype
        self._hidden_states_dtype = self.dtype

        self.is_multimodal_model = self.model_config.is_multimodal_model
        self.sliding_window = self.model_config.get_sliding_window()
        self.block_size = self.cache_config.block_size
        self.max_model_len = self.model_config.max_model_len
        self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        if self.most_model_len is not None:
            self.num_blocks_per_most_len_req = cdiv(self.most_model_len,
                                                    self.block_size)
        else:
            self.num_blocks_per_most_len_req = None

        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
                                MIN_NUM_SEQS)
        self.num_tokens_paddings = _get_token_paddings(
            min_token_size=16,
            max_token_size=self.scheduler_config.max_num_batched_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
        # padded max value to pre-allocate data structures and pre-compile.
        self.max_num_tokens = self.num_tokens_paddings[-1]

        # Model geometry.
        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
            self.parallel_config, LayerBlockType.attention)
        self.num_query_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_kv_heads = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.head_size = self.model_config.get_head_size()
        self.hidden_size = self.model_config.get_hidden_size()
        self.vocab_size = self.model_config.get_vocab_size()

        if self.lora_config is not None:
            self.vocab_size += self.lora_config.lora_extra_vocab_size

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = self.model_config.uses_mrope
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        (self.max_num_encoder_input_tokens,
         self.encoder_cache_size) = compute_encoder_budget(
             model_config=self.model_config,
             scheduler_config=self.scheduler_config,
             mm_registry=self.mm_registry,
         )

        page_size_bytes = get_page_size_bytes(
            block_size=self.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        self._num_slices_per_kv_cache_update_block = \
            _get_num_slices_per_kv_cache_update_block(page_size_bytes)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Initialize input batch early to avoid AttributeError in _update_states
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
        )

        # Cached torch/numpy tensors.
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        cpu_int32 = dict(dtype=torch.int32, device="cpu")
        pinned_cpu_int32 = dict(cpu_int32, pin_memory=self.pin_memory)

        self.input_ids_cpu = torch.zeros(self.max_num_tokens, **cpu_int32)
        self.positions_cpu = torch.zeros(self.max_num_tokens, **cpu_int32)
        self.positions_np = self.positions_cpu.numpy()
        self.block_table_cpu = torch.zeros(
            (self.max_num_reqs, self.max_num_blocks_per_req), **cpu_int32)

        # adjust num_reqs to avoid SMEM OOM.
        if self.most_model_len is not None:
            self.num_reqs_most_model_len = min(
                PallasAttentionBackend.get_max_num_seqs(
                    self.most_model_len, self.block_size),
                self.max_num_reqs)
        else:
            self.num_reqs_most_model_len = None
        self.num_reqs_max_model_len = min(
            PallasAttentionBackend.get_max_num_seqs(self.max_model_len,
                                                    self.block_size),
            self.max_num_reqs)

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                               **pinned_cpu_int32)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
                                        **pinned_cpu_int32)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64)
        self.num_reqs_paddings = _get_req_paddings(
            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}

        # Host tensors for structured decoding.
        self.grammar_bitmask_cpu = torch.zeros(
            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
            **pinned_cpu_int32)
        self.require_structured_out_cpu = torch.zeros(
            (self.max_num_reqs, 1),
            dtype=torch.bool,
            device="cpu",
            pin_memory=self.pin_memory)
        self.structured_decode_arange = torch.arange(
            0, 32, device="cpu", pin_memory=self.pin_memory)

        # Get maximum number of mm items per modality (batch size).
        self.max_num_mm_items_by_modality = {}
        encoder_enabled = (self.is_multimodal_model
                           and self.max_num_encoder_input_tokens > 0
                           and self.encoder_cache_size > 0)
        if encoder_enabled:
            max_tokens_by_modality = (
                MULTIMODAL_REGISTRY.
                get_max_tokens_per_item_by_nonzero_modality(self.model_config))
            for modality, max_tokens in max_tokens_by_modality.items():
                # How many items of this modality fit in the encoder budget.
                encoder_budget = min(self.max_num_encoder_input_tokens,
                                     self.encoder_cache_size)
                items_by_encoder = cdiv(encoder_budget, max_tokens)

                # How many items of this modality the decoder side allows.
                # NOTE: We do not consider max_num_batched_tokens on purpose
                # because the multimodal embeddings can be generated in advance
                # and chunked prefilled.
                per_req_limit = self.mm_registry.get_mm_limits_per_prompt(
                    self.model_config)[modality]
                items_by_decoder = self.max_num_reqs * per_req_limit

                self.max_num_mm_items_by_modality[modality] = min(
                    items_by_encoder, items_by_decoder)

        # Under SPMD the sampler is used as-is; otherwise it is compiled as
        # its own static-shape graph.
        if self.use_spmd:
            self.sample_from_logits_func = self.sample_from_logits
        else:
            self.sample_from_logits_func = torch.compile(
                self.sample_from_logits,
                backend="openxla",
                fullgraph=True,
                dynamic=False)

332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
    def _update_num_xla_graphs(self, case_str):
        """Sync `num_xla_graphs` with the runtime's cached-graph count.

        Does nothing unless recompilation checking is enabled and eager mode
        is off. When the count has changed, logs the difference together
        with `case_str` and adds it to `self.num_xla_graphs`.
        """
        if not self.check_recompilation or self.enforce_eager:
            return

        delta = xr.get_num_cached_compilation_graph() - self.num_xla_graphs
        if delta == 0:
            return

        logger.info("Add new %d compiled XLA graphs due to %s", delta,
                    case_str)
        self.num_xla_graphs += delta

    def _verify_num_xla_graphs(self, case_str):
        """Assert that the cached-graph count still equals `num_xla_graphs`.

        Does nothing unless recompilation checking is enabled and eager mode
        is off. `case_str` is included in the assertion message.
        """
        if not self.check_recompilation or self.enforce_eager:
            return

        cached_now = xr.get_num_cached_compilation_graph()
        assert self.num_xla_graphs == cached_now, (
            "Recompilation after warm up is detected during {}."
            " num_xla_graphs = {} curr_cached_graph = {}".format(
                case_str, self.num_xla_graphs, cached_now))

357
358
359
360
361
362
363
364
    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Apply the scheduler output to the cached request states and the
        persistent batch.

        Finished requests are dropped, freed encoder outputs are evicted,
        requests that are not scheduled in this step are taken out of the
        persistent batch (their cached state is kept), new requests are
        registered, and running/resumed requests get their computed-token
        counts and block IDs refreshed.

        Returns:
            True if at least one request was removed from the persistent
            batch because it is not scheduled in this step, or at least one
            new/resumed request was added to it. False otherwise.
        """
        finished_req_ids = scheduler_output.finished_req_ids

        # Drop the cached state of finished requests.
        for req_id in finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Take the finished requests out of the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in finished_req_ids:
            freed_index = self.input_batch.remove_request(req_id)
            if freed_index is not None:
                removed_req_indices.append(freed_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            cached_outputs = self.encoder_cache.get(req_id)
            if cached_outputs is None:
                continue
            cached_outputs.pop(input_id, None)
            if not cached_outputs:
                self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        unscheduled_req_ids = (self.input_batch.req_id_to_index.keys() -
                               scheduler_output.num_scheduled_tokens.keys())
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            freed_index = self.input_batch.remove_request(req_id)
            assert freed_index is not None
            removed_req_indices.append(freed_index)

        req_ids_to_add: list[str] = []

        # Register the newly scheduled requests.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            assert new_req_data.sampling_params is not None,\
                "Pooling is not supported in TPU yet"
            req_id = new_req_data.req_id
            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=new_req_data.sampling_params,
                pooling_params=None,
                generator=None,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            req_ids_to_add.append(req_id)

        # Refresh the running/resumed requests.
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            new_block_ids = cached_reqs.new_block_ids[i]
            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

            req_state.num_computed_tokens = num_computed_tokens
            if resumed_from_preemption:
                # Resumed after preemption: the new block IDs replace the
                # old ones entirely.
                req_state.block_ids = new_block_ids
            else:
                # Still running: append the new blocks to the existing ones.
                for existing_ids, extra_ids in zip(req_state.block_ids,
                                                   new_block_ids):
                    existing_ids.extend(extra_ids)

            batch_index = self.input_batch.req_id_to_index.get(req_id)
            if batch_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Mirror the update into the persistent batch.
            self.input_batch.num_computed_tokens_cpu[batch_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(new_block_ids,
                                                    batch_index)

        # Add the new or resumed requests to the persistent batch.
        # Sorting in descending order lets pop() hand out the smallest free
        # index first.
        removed_req_indices.sort(reverse=True)
        for req_id in req_ids_to_add:
            # Reuse a freed slot if one is left; None appends at the end.
            target_index = (removed_req_indices.pop()
                            if removed_req_indices else None)
            self.input_batch.add_request(self.requests[req_id], target_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        return bool(unscheduled_req_ids) or bool(req_ids_to_add)

    def get_model(self) -> nn.Module:
        return self.model

489
490
491
492
493
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """Return the pooling tasks reported by the model's pooler.

        Returns an empty list when the model is not a pooling model.
        """
        model = self.get_model()
        if is_pooling_model(model):
            return list(model.pooler.get_supported_tasks())
        return []
495

496
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Build the KV cache spec for every Attention layer found in the
        vLLM config.

        Layers that reuse another layer's KV cache are recorded in
        `self.shared_kv_cache_layers` instead of getting their own spec.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: For encoder-decoder attention layers.
            ValueError: For an unrecognised attention type.
        """
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: dict[str, KVCacheSpec] = {}

        for layer_name, attn_module in attn_layers.items():
            kv_tgt_layer = attn_module.kv_sharing_target_layer_name
            if kv_tgt_layer is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue

            attn_type = attn_module.attn_type
            if attn_type in (AttentionType.ENCODER,
                             AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            if attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            if attn_type != AttentionType.DECODER:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

            # Decoder attention: the spec class depends on whether the layer
            # uses a sliding window.
            shared_kwargs = dict(
                block_size=block_size,
                num_kv_heads=attn_module.num_kv_heads,
                head_size=attn_module.head_size,
                dtype=self.kv_cache_dtype,
                use_mla=False,
            )
            if attn_module.sliding_window is None:
                kv_cache_spec[layer_name] = FullAttentionSpec(**shared_kwargs)
            else:
                kv_cache_spec[layer_name] = SlidingWindowSpec(
                    sliding_window=attn_module.sliding_window,
                    **shared_kwargs)

        return kv_cache_spec

551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
    def _get_slot_mapping_metadata(self, num_reqs,
                                   num_scheduled_tokens_per_req):
        """
        Computes metadata for mapping slots to blocks in the key-value (KV)
        cache for a batch of requests.

        This function determines, for each request in the batch, how the
        scheduled tokens are distributed across memory blocks, and generates
        metadata needed to map slices of tokens to their corresponding positions
        in the KV cache.

        Args:
            num_reqs (int): Number of requests in the current batch.
            num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
            to be scheduled for each request.

        Returns:
            np.ndarray: A 2D array of shape (total_block_len, 3), where each row
            contains:
                - kv_cache_start_index (int): The starting index in the KV cache
                    for the corresponding slice.
                - new_kv_start_index (int): The starting index in the new KV
                    cache for the corresponding slice.
                - slice_len (int): The length of the slice.
        """
        slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
        slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
            num_scheduled_tokens_per_req
        local_block_start_idx = slices_start // self.block_size
        local_block_end_idx = (slices_end - 1) // self.block_size
        no_repeat_req_indices = self.arange_np[:num_reqs]
        global_block_start_idx = (
            no_repeat_req_indices * self.max_num_blocks_per_req +
            local_block_start_idx)
        block_lens = local_block_end_idx - local_block_start_idx + 1
        global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
        slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
        global_block_indices = global_block_start_idx + slice_arange
        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
        total_block_len = np.sum(block_lens)
        slot_mapping_slices = np.repeat(np.array([[0, self.block_size]],
                                                 dtype=np.int32),
                                        total_block_len,
                                        axis=0)
        cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
        np.cumsum(block_lens, out=cu_block_lens[1:])
        for req_idx in range(num_reqs):
            slot_mapping_slices[cu_block_lens[req_idx]][
                0] = slices_start[req_idx] % self.block_size
            slot_mapping_slices[
                cu_block_lens[req_idx + 1] -
                1][1] = (slices_end[req_idx] - 1) % self.block_size + 1
        slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
        cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
        np.cumsum(slice_lens, out=cu_slices_lens[1:])
        kv_cache_start_indices = slot_mapping_slices[:, 0] + \
            (block_numbers * self.block_size)
        new_kv_start_indices = cu_slices_lens[:-1]
        slot_mapping_metadata = np.stack(
            [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
        return slot_mapping_metadata

614
615
616
    def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
                        start_index: int):
        """Build the device inputs and attention metadata for one execution.

        The requests of the persistent batch starting at ``start_index`` are
        considered. If they do not all fit into a single execution, only a
        leading slice is prepared and ``end_index`` tells the caller where
        the next execution has to start.

        Side effects: fills the CPU staging buffers (``input_ids_cpu``,
        ``positions_np``, ``query_start_loc_np``, ``seq_lens_np``,
        ``block_table_cpu``), sets ``self.input_ids`` / ``self.position_ids``
        to the padded device tensors and, when LoRA is configured, activates
        the LoRA adapters for the padded token counts.

        Args:
            scheduler_output: Scheduler decision for this step; only
                ``total_num_scheduled_tokens`` and ``num_scheduled_tokens``
                are read here.
            start_index: Index of the first request to include.

        Returns:
            A tuple ``(per_layer_attn_metadata, logits_indices,
            padded_num_reqs, num_reqs, end_index)`` where
            ``per_layer_attn_metadata`` maps every attention layer name to
            the same ``PallasMetadata`` object, ``num_reqs`` is the number of
            requests covered by this execution and ``end_index`` is the
            start index for the following execution.
        """
        assert scheduler_output.total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0
        assert start_index < num_reqs

        # Get the number of scheduled tokens for each request.
        use_max_model_len = self.most_model_len is None
        num_scheduled_tokens_per_req = []

        # Use either most_model_len or max_model_len depending on request size.
        for i in range(start_index, num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if not use_max_model_len and num_tokens > self.most_model_len:
                use_max_model_len = True
            num_scheduled_tokens_per_req.append(num_tokens)

        # Cap the number of requests handled by this execution; the remaining
        # ones are picked up by the caller through `end_index`.
        max_reqs_this_run = (self.num_reqs_max_model_len if use_max_model_len
                             else self.num_reqs_most_model_len)
        if len(num_scheduled_tokens_per_req) > max_reqs_this_run:
            num_scheduled_tokens_per_req = \
                num_scheduled_tokens_per_req[:max_reqs_this_run]
            end_index = start_index + max_reqs_this_run
        else:
            end_index = num_reqs
        max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
        assert max_num_scheduled_tokens_all_reqs > 0

        # From here on `num_reqs` is the number of requests in this execution.
        num_reqs = len(num_scheduled_tokens_per_req)

        # NOTE(review): everything below indexes the persistent batch with
        # `[:num_reqs]` / `arange_np[:num_reqs]`, i.e. rows 0..num_reqs-1,
        # without adding `start_index`. That matches the selected requests
        # only when start_index == 0 - confirm the behaviour for split
        # batches.

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, what are the corresponding req index.
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, what is its position in corresponding req.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_np[1:num_reqs + 1])
        # Entries past the real requests are filled with the constant 1.
        self.query_start_loc_np[num_reqs + 1:] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = _get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens)
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[
            total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
        self.input_ids = self.input_ids_cpu[
            :padded_total_num_scheduled_tokens].to(self.device)
        # presumably `positions_np` (written above) is a numpy view of
        # `positions_cpu` - verify where the buffers are allocated.
        self.position_ids = self.positions_cpu[
            :padded_total_num_scheduled_tokens].to(self.device)

        if use_max_model_len:
            block_tables = self.block_table_cpu[
                :self.num_reqs_max_model_len, :self.max_num_blocks_per_req]
            block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
                self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
            query_start_loc = self.query_start_loc_cpu[
                :self.num_reqs_max_model_len + 1].to(self.device)
            seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to(
                self.device)
        else:
            block_tables = self.block_table_cpu[
                :self.num_reqs_most_model_len,
                :self.num_blocks_per_most_len_req]
            block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = (
                self.input_batch.block_table[0].get_cpu_tensor()
                [:num_reqs, :self.num_blocks_per_most_len_req])
            query_start_loc = self.query_start_loc_cpu[
                :self.num_reqs_most_model_len + 1].to(self.device)
            seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to(
                self.device)
        block_tables = block_tables.to(self.device)

        # Calculate the slot mapping
        slot_mapping_metadata = self._get_slot_mapping_metadata(
            num_reqs, num_scheduled_tokens_per_req)
        num_kv_update_slices = slot_mapping_metadata.shape[0]
        padded_num_slices = _get_padded_num_kv_cache_update_slices(
            padded_total_num_scheduled_tokens, self.max_num_reqs,
            self.block_size, self._num_slices_per_kv_cache_update_block)
        # Pad the slice rows with zeros up to the padded slice count, then
        # transpose so the three metadata fields become the leading axis.
        slot_mapping_metadata = np.pad(
            slot_mapping_metadata,
            [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
            constant_values=0)
        slot_mapping_metadata = np.transpose(slot_mapping_metadata)
        slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
                                             device=self.device)

        if self.lora_config is not None:
            # We need to respect padding when activating LoRA adapters
            padded_num_scheduled_tokens_per_req = np.copy(
                num_scheduled_tokens_per_req
            )  # Copying to avoid accidental state corruption bugs
            # The padding tokens are attributed to the last request.
            padded_num_scheduled_tokens_per_req[-1] += \
                padded_total_num_scheduled_tokens - total_num_scheduled_tokens

            self.set_active_loras(self.input_batch,
                                  padded_num_scheduled_tokens_per_req)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping_metadata,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs],
                                  dtype=torch.int32,
                                  device=self.device),
            num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                              dtype=torch.int32,
                                              device=self.device),
            num_slices_per_kv_cache_update_block=self.
            _num_slices_per_kv_cache_update_block,
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)

        # Every attention layer shares the same metadata object.
        layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                  Attention).keys()
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in layer_names
        }
        return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
            num_reqs, end_index
808

809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
    def _scatter_placeholders(
        self,
        embeds: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if is_embed is None:
            return embeds

        placeholders = embeds.new_full(
            (is_embed.shape[0], embeds.shape[-1]),
            fill_value=torch.nan,
        )
        placeholders[is_embed] = embeds
        return placeholders

    def _gather_placeholders(
        self,
        placeholders: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if is_embed is None:
            return placeholders

        return placeholders[is_embed]

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder for the scheduled inputs and cache it.

        Collects the multimodal inputs listed in
        ``scheduler_output.scheduled_encoder_inputs``, runs
        ``self.model.get_multimodal_embeddings`` once per modality group and
        stores the outputs in ``self.encoder_cache[req_id][input_id]``.
        Returns ``None``; does nothing when no encoder input is scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_mm_inputs,
                device=self.device,
            )

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            xm.mark_step()
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)
            xm.mark_step()

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            if isinstance(curr_group_outputs, torch.Tensor):
                # NOTE(review): the whole tensor is stored as ONE entry even
                # though it may hold several items; the zip below would then
                # pair it with a single (req_id, input_id) - confirm this is
                # intended for groups with more than one item.
                encoder_outputs.append(curr_group_outputs)
            else:
                assert isinstance(curr_group_outputs, (list, tuple))
                encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs.
        # NOTE (NickLucche) here we diverge from logic in other runners, as we
        # assume to only have whole mm items to process. Hence we avoid the
        # intrinsic dynamism that `scatter_mm_placeholders` introduces.
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            assert pos_info.is_embed is None, "Expected all positions to be"\
                " contiguous and embeddings."
            self.encoder_cache[req_id][input_id] = output
904
905

    def _gather_mm_embeddings(
906
907
908
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
909
        mm_embeds: list[torch.Tensor] = []
910
911
912
913
914
915
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
916
917
918
919
            # TODO unroll loop and assume/enforce --disable_chunked_mm_input
            # NOTE (NickLucche) here we diverge from logic in other runners, as
            # we assume to only have whole mm items to process. Hence we avoid
            # the intrinsic dynamism that `gather_mm_placeholders` introduces.
920
            for i, pos_info in enumerate(mm_positions):
921
922
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
938
939
                assert pos_info.is_embed is None, "Expected all positions to"\
                " be contiguous and embeddings."
940
                encoder_output = self.encoder_cache[req_id][i]
941
                mm_embeds.append(encoder_output)
942
        return mm_embeds
943

944
945
946
947
948
949
    def _get_model_inputs(self, input_ids: torch.Tensor,
                          mm_embeds: list[torch.Tensor]):
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
950
951
952
953
            inputs_embeds = self.model.get_input_embeddings(
                input_ids=input_ids,
                multimodal_embeddings=mm_embeds,
            )
954
955
956
957
958
959
960
961
            return None, inputs_embeds
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            return input_ids, None

962
963
964
965
    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> ModelRunnerOutput:
        """Run one scheduler step and return the sampled tokens.

        The scheduled requests may be split over several model executions
        (see ``_prepare_inputs``); the sampled token ids and logprobs of all
        executions are concatenated. Afterwards the sampled tokens are
        appended to the cached per-request state.

        Args:
            scheduler_output: Scheduler decision for this step.
            intermediate_tensors: Not used in this method.

        Returns:
            ``EMPTY_MODEL_RUNNER_OUTPUT`` when nothing is scheduled,
            otherwise a ``ModelRunnerOutput`` for the requests in the batch.
        """
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []
        xm.mark_step()
        # Prepare inputs, the requests might be split into multiple
        # executions, combine the result of each execution.
        start_index = 0
        combined_selected_tokens: list[torch.Tensor] = []
        combined_logprobs: list[LogprobsLists] = []
        while start_index < self.input_batch.num_reqs:
            attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
                end_index = self._prepare_inputs(scheduler_output, start_index)
            input_ids, inputs_embeds = self._get_model_inputs(
                self.input_ids, mm_embeds)
            xm.mark_step()
            # Run the decoder
            with set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=scheduler_output.total_num_scheduled_tokens):
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self.position_ids,
                    inputs_embeds=inputs_embeds,
                )
            hidden_states = self.select_hidden_states(hidden_states,
                                                      logits_indices)
            logits = self.compute_logits(hidden_states)
            # NOTE(review): the sampling metadata is built from the start of
            # the input batch regardless of `start_index` - confirm this is
            # correct when the batch is split into several executions.
            tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
                from_input_batch(self.input_batch, padded_num_reqs, self.device)
            if scheduler_output.grammar_bitmask is not None:
                require_struct_decoding, grammar_bitmask_padded, arange = \
                    self.prepare_structured_decoding_input(logits,
                                                           scheduler_output)
                logits = self.structured_decode(require_struct_decoding,
                                                grammar_bitmask_padded, logits,
                                                arange)
            selected_token_ids = self.sample_from_logits_func(
                logits, tpu_sampling_metadata)
            # NOTE (NickLucche) Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs. We can't enforce it
            # due to recompilations outside torch.compiled code, so just make
            # sure `sample_from_logits` does not modify the logits in-place.
            logprobs = self.gather_logprobs(logits, selected_token_ids) \
                if tpu_sampling_metadata.logprobs else None

            # Remove padding on cpu and keep dynamic op outside of xla graph.
            selected_token_ids = selected_token_ids.cpu()[:num_reqs]

            combined_selected_tokens.append(selected_token_ids)
            if tpu_sampling_metadata.logprobs:
                combined_logprobs.append(logprobs.tolists())

            start_index = end_index

        selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
        # `tpu_sampling_metadata` is the one from the last execution above.
        if tpu_sampling_metadata.logprobs:

            def concat_lists(input_lists):
                result = []
                for input_list in input_lists:
                    result.extend(input_list)
                return result

            logprobs_lists = LogprobsLists(
                logprob_token_ids=concat_lists(
                    [lp.logprob_token_ids for lp in combined_logprobs]),
                logprobs=concat_lists(
                    [lp.logprobs for lp in combined_logprobs]),
                sampled_token_ranks=concat_lists(
                    [lp.sampled_token_ranks for lp in combined_logprobs]))
        else:
            logprobs_lists = None

        # Update the cache state concurrently. Code above will not block until
        # we use `selected_token_ids`. Add mark_step if post-processing changes
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices = []
        num_reqs = self.input_batch.num_reqs
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        # Prompt logprobs are not produced here; every request maps to None.
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {
            req_id: None
            for req_id in req_ids
        }

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            #       the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1

        else:
            # Several tokens per request: drop the INVALID_TOKEN_ID entries
            # and split the remaining ids back into one list per request.
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

        return model_runner_output

1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Apply per-config overrides to the runner's config attributes.

        ``overrides`` maps a config attribute name to the overrides for that
        config. Only ``load_config`` and ``model_config`` are accepted; any
        other name trips the assertion. Each affected attribute is replaced
        by the result of the module-level ``update_config`` helper.
        """
        # TODO: TPU config may need extra validation
        # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
        allowed_config_names = {"load_config", "model_config"}
        for config_name in overrides:
            assert config_name in allowed_config_names, \
                f"Config `{config_name}` not supported. " \
                f"Allowed configs: {allowed_config_names}"
            current_config = getattr(self, config_name)
            setattr(self, config_name,
                    update_config(current_config, overrides[config_name]))

1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
    def load_model(self) -> None:
        """Load the model (or reload its weights in place) onto the device.

        On the SPMD path a ``TPUModelLoader`` is used. Otherwise the regular
        model loader either builds a new model or, when ``self.model``
        already exists, reloads the weights into it. When LoRA is configured
        a newly built model is wrapped with the LoRA layers. Finally pending
        XLA work is synchronised and ``self.sampler`` is created.

        Raises:
            RuntimeError: If loading fails with a ``RuntimeError``; the
                original error is chained and its text included.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        # Stays None when the weights are reloaded into the existing
        # `self.model`, i.e. when no new model object is created.
        model = None
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            try:
                if self.use_spmd:
                    tpu_loader = TPUModelLoader(
                        load_config=self.vllm_config.load_config)
                    model = tpu_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.vllm_config.model_config,
                        mesh=self.mesh)
                else:
                    model_loader = get_model_loader(self.load_config)
                    if not hasattr(self, "model"):
                        logger.info("Loading model from scratch...")
                        model = model_loader.load_model(
                            vllm_config=self.vllm_config,
                            model_config=self.model_config)
                    else:
                        logger.info("Model was already initialized. "
                                    "Loading weights inplace...")
                        model_loader.load_weights(
                            self.model, model_config=self.model_config)
            except RuntimeError as e:
                raise RuntimeError(
                    "Unable to load model, a likely reason is the model is "
                    "too large for the current device's HBM memory. "
                    "Consider switching to a smaller model "
                    "or sharding the weights on more chips. "
                    f"See the detailed error: {e}") from e
        # Only a newly created model is wrapped with LoRA. On an in-place
        # reload `self.model` was already wrapped when it was first loaded,
        # and there is no new `model` object to wrap.
        if model is not None and self.lora_config is not None:
            model = self.load_lora_model(model, self.model_config,
                                         self.scheduler_config,
                                         self.lora_config, self.device)
            replace_set_lora(model)

        # Sync all pending XLA execution during model initialization and weight
        # loading.
        xm.mark_step()
        xm.wait_device_ops()
        # NOTE(review): if `self.model` already exists, a model newly loaded
        # on the SPMD path above is not installed - confirm this is intended.
        if not hasattr(self, "model"):
            self.model = model
        self.sampler = TPUSampler()
1203

1204
    @torch.no_grad()
    def _dummy_run(self, num_tokens: int, num_reqs: int,
                   num_blocks: int) -> None:
        """
        Run the model once on zero-filled dummy inputs of the given shape.

        Called by `_precompile_backbone` and `profile_run` to trigger
        compilation for one input shape. The dtype of the model output is
        stored in `self._hidden_states_dtype`.

        Args:
            num_tokens: Length of the dummy token (or embedding) and position
                tensors.
            num_reqs: Number of rows of the dummy block table; also the
                number of single-token queries used to build
                `query_start_loc` and `context_lens`.
            num_blocks: Number of columns of the dummy block table.
        """
        # Multimodal models are fed embeddings; other models token ids.
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
                                        device=self.device)
        else:
            input_ids = torch.zeros((num_tokens),
                                    dtype=torch.int32).to(self.device)
            inputs_embeds = None
        # There cannot be more requests than tokens.
        actual_num_reqs = min(num_tokens, num_reqs)
        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32).to(self.device)
        padded_num_slices = _get_padded_num_kv_cache_update_slices(
            num_tokens, self.max_num_reqs, self.block_size,
            self._num_slices_per_kv_cache_update_block)
        num_kv_update_slices = torch.tensor([padded_num_slices],
                                            dtype=torch.int32).to(self.device)
        slot_mapping = torch.zeros((3, padded_num_slices),
                                   dtype=torch.int32).to(self.device)
        block_tables = torch.zeros((num_reqs, num_blocks),
                                   dtype=torch.int32).to(self.device)
        # Every dummy request contributes exactly one query token.
        query_lens = [1] * num_reqs
        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                    dtype=torch.int32),
                                       dim=0,
                                       dtype=torch.int32).to(self.device)
        context_lens = torch.ones((num_reqs, ),
                                  dtype=torch.int32).to(self.device)
        num_seqs = torch.tensor([actual_num_reqs],
                                dtype=torch.int32).to(self.device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
            num_kv_update_slices=num_kv_update_slices,
            num_slices_per_kv_cache_update_block=self.
            _num_slices_per_kv_cache_update_block,
        )

        # Mark the varying dimensions of the inputs as dynamic for dynamo.
        if self.is_multimodal_model:
            torch._dynamo.mark_dynamic(inputs_embeds, 0)
        else:
            torch._dynamo.mark_dynamic(input_ids, 0)
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
        torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)

        # All attention layers share the same dummy metadata object.
        layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                  Attention).keys()
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in layer_names
        }

        with self.maybe_select_dummy_loras(
                self.lora_config,
                np.array([num_tokens], dtype=np.int32)), set_forward_context(
                    per_layer_attn_metadata, self.vllm_config, 0):
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             inputs_embeds=inputs_embeds)
        # Later precompile steps build their dummy tensors with this dtype.
        self._hidden_states_dtype = out.dtype
1273

1274
1275
1276
1277
1278
1279
1280
    def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                          lora_requests) -> None:
        """
        Delegate to the parent `_set_active_loras`, with an `xm.mark_step()`
        issued immediately before and immediately after the call.
        """
        # Captures input updates
        xm.mark_step()
        super()._set_active_loras(
            prompt_lora_mapping,
            token_lora_mapping,
            lora_requests,
        )
        # Captures metadata updates
        xm.mark_step()

1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
    def _precompile_mm_encoder(self) -> None:
        """
        Pre-compile the multimodal encoder and the input-embedding merge.

        For every modality in `self.max_num_mm_items_by_modality` this runs
        `self.model.get_multimodal_embeddings` on dummy batches of
        1..max items, then runs `self._get_model_inputs` for every token
        padding large enough to hold the resulting embeddings, and finally
        runs `self._get_model_inputs` without multimodal embeddings for every
        token padding.
        """
        # Pre-compile MM encoder for all supported data modalities.
        hf_config = self.vllm_config.model_config.hf_config
        for mode, max_items_by_mode in \
            self.max_num_mm_items_by_modality.items():
            logger.info(
                "Compiling Multimodal %s Encoder with different input"
                " shapes.", mode)
            start = time.perf_counter()
            # No padding for MM encoder just yet.
            for num_items in range(1, max_items_by_mode + 1):
                logger.info("  -- mode: %s items: %d", mode, num_items)
                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                    mode, num_items)
                # Run multimodal encoder.
                xm.mark_step()
                mm_embeds = self.model.\
                    get_multimodal_embeddings(**batched_dummy_mm_inputs)
                xm.mark_step()
                num_patches = mm_embeds[0].shape[0]
                items_size = num_patches * num_items

                # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm
                # embeddings are present. We assume `--disable-mm-chunked`,
                # hence only whole items can be scheduled. This implies we just
                # need to compile when `num_items` fit the (padded) `input_ids`
                for num_tokens in self.num_tokens_paddings:
                    if num_tokens >= items_size:
                        # XLA Workaround: if torch.zeros(..device) is used, XLA
                        # compiles a scalar+expansion op, which won't match
                        # the graph generated at runtime. CPU->TPU must be used
                        placeholders_ids = torch.zeros(num_tokens,
                                                       dtype=torch.int32,
                                                       device="cpu")
                        # Align placeholders and actual num mm_embeddings.
                        placeholders_ids[:items_size] = \
                            hf_config.image_token_index

                        placeholders_ids = placeholders_ids.to(self.device)
                        # Assign outputs or the graph will be cut short.
                        a, _ = self._get_model_inputs(placeholders_ids,
                                                      [mm_embeds])
                        assert a is None
                        xm.mark_step()

            # Pre-compile `get_input_embeddings` when mm_embeddings are not
            # present. Chunk is only made of text, no mm_placeholders.
            for num_tokens in self.num_tokens_paddings:
                placeholders_ids = torch.zeros(num_tokens,
                                               dtype=torch.int32,
                                               device="cpu")
                placeholders_ids = placeholders_ids.to(self.device)
                a, _ = self._get_model_inputs(placeholders_ids, [])
                assert a is None
                xm.mark_step()

            xm.wait_device_ops()
            end = time.perf_counter()
            logger.info(
                "Multimodal %s Encoder compilation finished in %.2f "
                "[secs].", mode, end - start)

1343
    def _precompile_backbone(self) -> None:
        """
        Run `_dummy_run` for every token padding in
        `self.num_tokens_paddings`, and additionally with the
        "most model len" request/block counts when `self.most_model_len` is
        set. Logs the elapsed time and updates the XLA graph count.
        """
        logger.info("Compiling the model with different input shapes.")
        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(num_tokens, self.num_reqs_max_model_len,
                            self.max_num_blocks_per_req)
            if self.most_model_len is not None:
                self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                                self.num_blocks_per_most_len_req)
        # Wait for pending device work so the timing below is meaningful.
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model backbone")
1357

1358
1359
1360
1361
1362
    def _precompile_select_hidden_states(self) -> None:
        """
        Call `select_hidden_states` on zero-filled dummy tensors for the
        combinations of token paddings and request paddings, then log the
        elapsed time and update the XLA graph count.
        """
        # Compile hidden state selection function for bucketed
        # n_tokens x max_num_reqs. Graph is really small so this is fine.
        logger.info(
            "Compiling select_hidden_states with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        for num_tokens in self.num_tokens_paddings:
            dummy_hidden = torch.zeros((num_tokens, hsize),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            torch._dynamo.mark_dynamic(dummy_hidden, 0)
            for num_reqs in self.num_reqs_paddings:
                indices = torch.zeros(num_reqs,
                                      dtype=torch.int32,
                                      device=self.device)
                torch._dynamo.mark_dynamic(indices, 0)
                self.select_hidden_states(dummy_hidden, indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                            num_reqs)
                # Requests can't be more than tokens. But do compile for the
                # next bigger value in case num_tokens uses bucketed padding.
                if num_reqs >= min(num_tokens, self.max_num_reqs):
                    break
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("select_hidden_states")
1386

1387
1388
    def _precompile_compute_logits(self) -> None:
        """
        Call `compute_logits` on a zero-filled `(num_reqs, hidden_size)`
        tensor for every entry of `self.num_reqs_paddings`, then log the
        elapsed time and update the XLA graph count.
        """
        logger.info("Compiling compute_logits with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        for num_reqs in self.num_reqs_paddings:
            dummy_hidden = torch.zeros((num_reqs, hsize),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            torch._dynamo.mark_dynamic(dummy_hidden, 0)
            self.compute_logits(dummy_hidden)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("compute_logits")

    def _precompile_structured_decoding(self) -> None:
        """
        Call `structured_decode` once per padded request count in
        `self.num_reqs_paddings`, then log the elapsed time and update the
        XLA graph count.
        """
        logger.info(
            "Compiling structured_decoding with different input shapes.")
        t_begin = time.perf_counter()
        device = self.device
        for padded_reqs in self.num_reqs_paddings:
            logits = torch.zeros((padded_reqs, self.vocab_size),
                                 device=device,
                                 dtype=self._hidden_states_dtype)
            needs_grammar = \
                self.require_structured_out_cpu[:padded_reqs].to(device)
            bitmask = self.grammar_bitmask_cpu[:padded_reqs].to(device)
            # None of the three tensors above may have its first dimension
            # marked dynamic: some operations in structured_decode require
            # them to be static.
            bit_offsets = self.structured_decode_arange.to(device)
            self.structured_decode(needs_grammar, bitmask, logits,
                                   bit_offsets)
            logger.info("  -- num_seqs: %d", padded_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - t_begin
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("structured_decoding")

    def _precompile_sample_from_logits(self) -> None:
        """
        Call `self.sample_from_logits_func` on zero-filled logits for every
        entry of `self.num_reqs_paddings`, once with `all_greedy` False and
        once with it True, then log the elapsed time and update the XLA
        graph count.
        """
        logger.info(
            "Compiling sample_from_logits with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            # The first dimension of dummy_logits cannot be mark_dynamic
            # because some operations in the sampler require it to be static.
            for all_greedy in [False, True]:
                generate_params_if_all_greedy = not all_greedy
                sampling_metadata = (
                    TPUSupportedSamplingMetadata.from_input_batch(
                        self.input_batch,
                        num_reqs,
                        self.device,
                        generate_params_if_all_greedy,
                    ))
                # Force the branch to compile regardless of what the input
                # batch reported.
                sampling_metadata.all_greedy = all_greedy
                with self.maybe_select_dummy_loras(
                        self.lora_config, np.array([num_reqs],
                                                   dtype=np.int32)):
                    self.sample_from_logits_func(dummy_logits,
                                                 sampling_metadata)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sample_from_logits")
1457

1458
1459
1460
1461
1462
1463
1464
1465
1466
    def _precompile_gather_logprobs(self) -> None:
        """
        Call `gather_logprobs` on zero-filled logits and token ids for every
        entry of `self.num_reqs_paddings`, then log the elapsed time and
        update the XLA graph count.
        """
        logger.info("Compiling gather_logprobs with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            dummy_tokens = torch.zeros((num_reqs, 1),
                                       dtype=torch.int64).to(self.device)
            with self.maybe_select_dummy_loras(
                    self.lora_config, np.array([num_reqs], dtype=np.int32)):
                self.gather_logprobs(dummy_logits, dummy_tokens)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("gather_logprobs")

1476
1477
1478
1479
    def capture_model(self) -> None:
        """
        Precompile all the subgraphs with possible input shapes.

        Runs every `_precompile_*` step, in order, inside the
        `maybe_setup_dummy_loras` context.
        """
        with self.maybe_setup_dummy_loras(self.lora_config):
            self._precompile_mm_encoder()
            self._precompile_backbone()
            self._precompile_select_hidden_states()
            self._precompile_compute_logits()
            self._precompile_structured_decoding()
            self._precompile_sample_from_logits()
            self._precompile_gather_logprobs()
1488

1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
    def profile_run(
        self,
        num_tokens: int,
    ) -> None:
        """
        Profile the model with dummy inputs.

        For multimodal models with a non-zero encoder budget, first runs the
        multimodal encoder on a dummy batch and stores its outputs in
        `self.encoder_cache["tmp"]`. Then runs `_dummy_run` with
        `num_tokens`, waits for the device, clears the encoder cache and
        triggers garbage collection.

        Args:
            num_tokens: Number of tokens passed to `_dummy_run`.
        """
        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            dummy_data_modality, max_num_mm_items = max(
                self.max_num_mm_items_by_modality.items(), key=lambda t: t[1])

            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            logger.info(
                "Encoder cache will be initialized with a budget of %d tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                dummy_data_modality, max_num_mm_items)

            # Run multimodal encoder.
            # Isolate encoder graph from post-processing to minimize
            # impact of recompilation until it's fixed.
            start = time.perf_counter()
            xm.mark_step()
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            xm.mark_step()
            xm.wait_device_ops()
            end = time.perf_counter()
            logger.info(
                "Multimodal Encoder profiling finished in %.2f [secs].",
                end - start)

            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        # Trigger compilation for general shape.
        self._dummy_run(num_tokens, self.num_reqs_max_model_len,
                        self.max_num_blocks_per_req)
        if self.most_model_len is not None:
            self._dummy_run(num_tokens, self.num_reqs_most_model_len,
                            self.num_blocks_per_most_len_req)

        xm.mark_step()
        xm.wait_device_ops()
        # Drop the dummy encoder outputs cached above.
        self.encoder_cache.clear()
        gc.collect()

1552
1553
1554
1555
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Allocates one zero-filled KV cache tensor per layer, binds the
        tensors through `bind_kv_cache`, and shards them when SPMD is
        enabled. The input batch is rebuilt when the configured block size
        differs from `self.block_size`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer

        Raises:
            NotImplementedError: If there is more than one KV cache group, or
                a group's spec is not an `AttentionSpec`.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        if kv_cache_config.kv_cache_groups[
                0].kv_cache_spec.block_size != self.block_size:
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=[
                    kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
                ],
            )
        # Verify dtype compatibility between block_table_cpu and input_batch
        assert self.block_table_cpu.dtype == self.input_batch.block_table[
            0].get_cpu_tensor().dtype

        # Map each layer name to the byte size of its KV cache tensor.
        kv_cache_sizes = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            assert len(kv_cache_tensor.shared_by) == 1, (
                "KV cache tensor shared by multiple layers is not supported in "
                "TPU.")
            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

        kv_caches: dict[str, torch.Tensor] = {}
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_size = kv_cache_sizes[layer_name]
                assert tensor_size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_size // kv_cache_spec.page_size_bytes  # noqa
                if isinstance(kv_cache_spec, AttentionSpec):
                    if self.use_spmd:
                        num_kv_heads = kv_cache_spec.num_kv_heads
                        assert self.original_parallel_config is not None
                        tp_size = \
                            self.original_parallel_config.tensor_parallel_size
                        # TODO: Handle kv cache duplication under SPMD mode.
                        assert num_kv_heads % tp_size == 0, (
                            f"num_kv_heads {num_kv_heads} must be divisible by "
                            f"tp_size {tp_size} under SPMD mode")
                    kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype

                    tpu_kv_cache = torch.zeros(kv_cache_shape,
                                               dtype=dtype).to(self.device)

                    kv_caches[layer_name] = tpu_kv_cache
                else:
                    raise NotImplementedError

        # Setup `kv_cache_config` and `kv_caches` for models
        # with cross-layer KV sharing
        if self.shared_kv_cache_layers:
            initialize_kv_cache_for_kv_sharing(
                self.shared_kv_cache_layers,
                kv_cache_config.kv_cache_groups,
                kv_caches,
            )

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

        if self.use_spmd:
            # Shard KV Cache
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))

1636
1637
    def reset_dynamo_cache(self) -> None:
        """
        Clear the dynamo cache and cached bytecode of the compiled model.

        The compiled module is `self.model.get_language_model().model` for
        multimodal models and `self.model.model` otherwise. Nothing happens
        unless that module is a `TorchCompileWrapperWithCustomDispatcher`.
        """
        if self.is_multimodal_model:
            compiled_model = self.model.get_language_model().model
        else:
            compiled_model = self.model.model
        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
            logger.info("Clear dynamo cache and cached dynamo bytecode.")
            torch._dynamo.eval_frame.remove_from_cache(
                compiled_model.original_code_object)
            compiled_model.compiled_codes.clear()
1646

1647
1648
1649
1650
1651
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def select_hidden_states(self, hidden_states, indices_do_sample):
        """
        Return `hidden_states` indexed by `indices_do_sample`.

        Compiled as its own graph with the "openxla" backend.
        """
        return hidden_states[indices_do_sample]

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def compute_logits(self,
                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Return `self.model.compute_logits(sample_hidden_states, None)`.

        Compiled as its own graph with the "openxla" backend.
        """
        return self.model.compute_logits(sample_hidden_states, None)

1656
1657
1658
    # TODO: Under SPMD mode, sample_from_logits has correctness issue.
    #       Re-enable the torch.compile once the issue is fixed in torchxla.
    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
1659
1660
1661
    def sample_from_logits(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
        """
        Sample with xla-friendly function. This function is to be traced
        separately from `forward` for lighter compilation overhead.

        When `sampling_metadata.all_greedy` is set, the result is the argmax
        over the last dimension of `logits` (with that dimension kept);
        otherwise it is the `sampled_token_ids` produced by `self.sampler`.
        """
        if sampling_metadata.all_greedy:
            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            out_tokens = self.sampler(logits,
                                      sampling_metadata).sampled_token_ids
        return out_tokens

1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def gather_logprobs(self, logits: torch.Tensor,
                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
        """
        Collect the top logprobs together with their tokens.

        A fixed count (`self.model_config.max_logprobs`) is always gathered
        so that a single pre-compiled graph suffices; the number of logprobs
        each request actually asked for is selected later on CPU.
        """
        sampler = self.sampler
        all_logprobs = sampler.compute_logprobs(logits)
        flat_token_ids = sampled_tokens.squeeze(-1)
        return sampler.gather_logprobs(
            all_logprobs,
            self.model_config.max_logprobs,
            token_ids=flat_token_ids,
        )

1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def structured_decode(self, require_struct_decoding: torch.Tensor,
                          grammar_bitmask: torch.Tensor, logits: torch.Tensor,
                          arange: torch.Tensor) -> torch.Tensor:
        """
        Select, via `torch.where` on `require_struct_decoding`, between the
        grammar-masked logits and the untouched `logits`.
        """
        masked_logits = self.apply_grammar_bitmask(logits, grammar_bitmask,
                                                   arange)
        return torch.where(require_struct_decoding, masked_logits, logits)

    def apply_grammar_bitmask(self, logits: torch.Tensor,
                              grammar_bitmask: torch.Tensor,
                              arange: torch.Tensor):
        """
        Return a copy of `logits` in which every position whose bit in the
        packed `grammar_bitmask` is 0 has been set to -inf.

        Each bitmask word is unpacked by right-shifting it by every value of
        `arange`; the unpacked row is truncated to `self.vocab_size`. The
        input `logits` tensor is not modified.
        """
        num_rows = logits.shape[0]
        assert (num_rows == grammar_bitmask.shape[0])
        masked = logits.clone()
        shifts = arange[None, :]
        for row in range(num_rows):
            words = grammar_bitmask[row][:, None]
            bits = torch.bitwise_right_shift(words, shifts) & 1
            # A cleared bit means the token is masked out.
            disallowed = (bits == 0).reshape(-1)[:self.vocab_size]
            masked[row] = masked[row].masked_fill(disallowed, -float("inf"))
        return masked

1709
1710
    def get_multimodal_embeddings(self, *args, **kwargs):
        """
        Forward all arguments unchanged to
        `self.model.get_multimodal_embeddings` and return its result.
        """
        model = self.model
        return model.get_multimodal_embeddings(*args, **kwargs)
1711

1712
1713
1714
    def get_input_embeddings(self, *args, **kwargs):
        """
        Forward all arguments unchanged to `self.model.get_input_embeddings`
        and return its result.
        """
        model = self.model
        return model.get_input_embeddings(*args, **kwargs)

1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
    def prepare_structured_decoding_input(
        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build the device-side inputs for `structured_decode`.

        The scheduler's bitmask rows are re-ordered to follow the batch order
        of `self.input_batch` and written into the pre-allocated CPU buffers.

        Returns:
            A tuple of (per-request "needs structured output" flags, grammar
            bitmask, arange), the first two sliced to the number of rows of
            `logits`, all moved to `logits.device`.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        assert grammar_bitmask is not None
        num_reqs, _ = logits.shape

        # Start from clean pre-allocated buffers.
        self.grammar_bitmask_cpu.zero_()
        self.require_structured_out_cpu.zero_()

        # The scheduler does not know how the tpu runner orders requests in
        # the batch, so the bitmask row of a request generally differs from
        # its batch row. Collect both so the rows can be matched up.
        req_to_mask_row = scheduler_output.structured_output_request_ids
        batch_rows: list[int] = []
        mask_rows: list[int] = []
        for req_id in self.input_batch.req_ids:
            mask_row = req_to_mask_row.get(req_id)
            if mask_row is not None:
                batch_rows.append(self.input_batch.req_id_to_index[req_id])
                mask_rows.append(mask_row)
        self.grammar_bitmask_cpu[batch_rows] = torch.from_numpy(
            grammar_bitmask[mask_rows])
        # Not every request in the batch necessarily needs structured
        # output; flag only the ones that do.
        flagged_rows = torch.tensor(batch_rows, dtype=torch.long)
        self.require_structured_out_cpu[flagged_rows] = True

        target = logits.device
        return (
            self.require_structured_out_cpu[:num_reqs].to(target),
            self.grammar_bitmask_cpu[:num_reqs].to(target),
            self.structured_decode_arange.to(target),
        )

1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
    def _get_mm_dummy_batch(self, modality: str,
                            batch_size: int) -> BatchedTensorInputs:
        """
        Build a dummy multimodal batch for pre-compilation and profiling.

        Args:
            modality: Modality whose first dummy item is replicated.
            batch_size: Number of copies of that item in the batch.

        Returns:
            The batched dummy inputs as keyword arguments, placed on
            `self.device`.
        """
        # Dummy data for pre-compiling multimodal models.
        dummy_request_data = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
        )
        dummy_mm_data = dummy_request_data.multi_modal_data

        # Dummy data definition in V0 may contain multiple multimodal items
        # (e.g, multiple images) for a single request, therefore here we
        # always replicate first item by batch_size times since in V1
        # they are scheduled to be processed separately.
        assert isinstance(dummy_mm_data, MultiModalKwargs), (
            "Expected dummy multimodal data to be of type "
            f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. "
            "This is most likely due to the model not having a merged "
            "processor.")

        # When models have a merged processor, their dummy data is
        # already batched `MultiModalKwargs`, therefore we take the first
        # `MultiModalKwargsItem` from the desired modality to profile on.
        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                         batch_size)
        return MultiModalKwargs.as_kwargs(
            batched_dummy_mm_inputs,
            device=self.device,
        )
1783

1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795

def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    """Build the list of request-count paddings up to max_req_size.

    The first entry is the larger of MIN_NUM_SEQS and min_req_size; each
    following entry comes from `_get_padded_num_reqs_with_upper_limit`
    applied to the previous entry plus one. Every entry is logged.
    """
    logger.info("Preparing request paddings:")
    # min_req_size must be a positive power of 2
    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
    buckets: list = []
    candidate = max(MIN_NUM_SEQS, min_req_size)
    while candidate <= max_req_size:
        # Stop once the capped candidate repeats the last bucket.
        if buckets and buckets[-1] == candidate:
            break
        buckets.append(candidate)
        logger.info("    %d", candidate)
        candidate = _get_padded_num_reqs_with_upper_limit(
            candidate + 1, max_req_size)
    return buckets
1796
1797


1798
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Pad a request count, capped at upper_limit.

    Values up to MIN_NUM_SEQS pad to MIN_NUM_SEQS; larger values pad to the
    smallest power of two that is >= x. The result never exceeds
    upper_limit.
    """
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)
1801
1802


1803
1804
def _get_token_paddings(min_token_size: int, max_token_size: int,
                        padding_gap: int) -> list[int]:
1805
1806
    """Generate a list of padding size, starting from min_token_size, 
    ending with a number that can cover max_token_size
1807

1808
1809
1810
    If padding_gap == 0 then:
        increase 2X each time (exponential)
    else:
1811
        first increase the size to twice,
1812
        then increase the padding size by padding_gap.
1813
    """
1814
1815
    # assert min_token_size is power of 2
    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
1816
1817
    paddings = []
    num = min_token_size
1818
1819

    if padding_gap == 0:
1820
        logger.info("Using exponential token paddings:")
1821
        while True:
1822
1823
            logger.info("    %d", num)
            paddings.append(num)
1824
1825
            if num >= max_token_size:
                break
1826
1827
            num *= 2
    else:
1828
        logger.info("Using incremental token paddings:")
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
        while num <= padding_gap:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

1839
1840
1841
1842
1843
1844
1845
1846
1847
    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
    """Return the first element in paddings list greater or equal to x.
    """
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]
1848
1849


1850
1851
1852
def _get_padded_num_kv_cache_update_slices(
        num_tokens: int, max_num_reqs: int, page_size: int,
        num_slices_per_kv_cache_update_block: int) -> int:
1853
1854
1855
1856
1857
    """Calculates the padded number of KV cache update slices to avoid
    recompilation."""
    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
    padded_num_slices = min(padded_num_slices, num_tokens)
    padded_num_slices = (
1858
1859
1860
        padded_num_slices + num_slices_per_kv_cache_update_block - 1
    ) // num_slices_per_kv_cache_update_block * \
        num_slices_per_kv_cache_update_block
1861
1862
1863
    return padded_num_slices


1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    """Find the optimum number of slices to copy per Pallas program instance.

    Increasing the number of slices copied in one instance of the kernel program
    will increase HBM bandwidth utilization via more in-flight DMAs.

    However, it will also use more VMEM, and experimentally, we observed
    performance regression at 128 slices on v6e, likely due to running
    out of scalar registers. Thus this function will limit the number of
    slices to 64.

    Args:
        page_size_bytes: Size of one KV cache page in bytes.

    Returns:
        The number of slices per block: the result of ``prev_power_of_2`` on
        the number of pages that fit in the VMEM budget, capped at 64.

    Raises:
        AssertionError: If not even one page fits in the VMEM budget.
    """
    # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
    # calculate num_slices_per_block based on 16MB in case any register spills.
    vmem_limit = 16 * 1024 * 1024
    max_num_slices = 64

    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    num_slices_per_block = prev_power_of_2(num_slices_per_block)
    return min(num_slices_per_block, max_num_slices)


1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
def replace_set_lora(model):
    """Patch the LoRA layers of ``model`` so that LoRA updates end with an
    XLA step marker.

    For every submodule that is a ``BaseLayerWithLoRA``, the current
    ``set_lora`` and ``reset_lora`` are saved as ``_original_set_lora`` and
    ``_original_reset_lora``, and replaced by bound wrappers that call the
    saved method and then ``xm.mark_step()``.
    """

    def _tpu_set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        # TODO: The integer index leads to a recompilation, but converting it
        # to a tensor doesn't seem to work anymore. This might be fixed with a
        # later release of torch_xla.
        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
        xm.mark_step()

    def _tpu_reset_lora(self, index: int):
        self._original_reset_lora(index)
        xm.mark_step()

    for _name, layer in model.named_modules():
        if not isinstance(layer, BaseLayerWithLoRA):
            continue
        # Keep the originals reachable before shadowing them on the instance.
        layer._original_set_lora = layer.set_lora
        layer._original_reset_lora = layer.reset_lora
        layer_cls = layer.__class__
        # __get__ binds the plain functions to this layer as methods.
        layer.set_lora = _tpu_set_lora.__get__(layer, layer_cls)
        layer.reset_lora = _tpu_reset_lora.__get__(layer, layer_cls)