tpu_model_runner.py 91.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import bisect
4
import gc
5
import time
6
from typing import TYPE_CHECKING, Any, cast
7
8
9
10
11
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
12

13
# TPU XLA related
14
import torch_xla
15
import torch_xla.core.xla_model as xm
16
import torch_xla.distributed.spmd as xs
17
18
import torch_xla.runtime as xr

19
import vllm.envs as envs
20
from vllm.attention import Attention
21
from vllm.attention.backends.abstract import AttentionType
22
from vllm.attention.layer import MLAAttention
23
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
24
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
25
26
27
28
29
30
31
from vllm.config import (
    ParallelConfig,
    VllmConfig,
    get_layers_from_vllm_config,
    update_config,
)
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
32
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
33
from vllm.forward_context import set_forward_context
34
from vllm.logger import init_logger
35
from vllm.lora.layers import BaseLayerWithLoRA
36
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
37
from vllm.model_executor.model_loader import get_model_loader
38
from vllm.model_executor.model_loader.tpu import TPUModelLoader
39
40
41
42
from vllm.model_executor.models.interfaces import (
    SupportsMultiModal,
    supports_transcription,
)
43
from vllm.model_executor.models.interfaces_base import (
44
45
46
    is_pooling_model,
    is_text_generation_model,
)
47
from vllm.multimodal import MULTIMODAL_REGISTRY
48
49
50
51
52
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    PlaceholderRange,
)
53
from vllm.multimodal.utils import group_mm_kwargs_by_modality
54
from vllm.sequence import IntermediateTensors
55
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
56
from vllm.utils.math_utils import cdiv, prev_power_of_2
57
from vllm.utils.platform_utils import is_pin_memory_available
58
59
60
61
62
63
64
65
66
67
68
from vllm.v1.attention.backends.pallas import (
    TPU_STR_DTYPE_TO_TORCH_DTYPE,
    PallasAttentionBackend,
    PallasMetadata,
    get_page_size_bytes,
)
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheSpec,
69
    MLAAttentionSpec,
70
71
72
73
74
75
76
77
    SlidingWindowSpec,
)
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    LogprobsLists,
    LogprobsTensors,
    ModelRunnerOutput,
)
78
79
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
80
from vllm.v1.worker.kv_connector_model_runner_mixin import (
81
82
83
    KVConnectorModelRunnerMixin,
    KVConnectorOutput,
)
84
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
85
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
86

87
88
89
90
91
92
from .utils import (
    MultiModalBudget,
    add_kv_sharing_layers_to_kv_cache_groups,
    bind_kv_cache,
    sanity_check_mm_encoder_outputs,
)
93

94
if TYPE_CHECKING:
95
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
96
97
98

logger = init_logger(__name__)

99
INVALID_TOKEN_ID = -1
100
101
# Smallest output size
MIN_NUM_SEQS = 8
102
103


104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#########################################################
# Ways to avoid recompilation
#########################################################
#
# The model executor has two primary components:
# 1. preparing the model and sampler inputs
# 2. executing the model and sampler.
# The core idea is to avoid any TPU computation during input preparation. For
# better compilation tracking and increased flexibility, the model execution and
# sampler are divided into several distinct components.
#
# Below are the detailed steps:
#
# Step 1
# It is recommended to avoid TPU operations when preparing the model and sampler
# inputs. CPU tensors can be prepared and transferred to the XLA device using
# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
# compilation.
#
# Step 2
# The TPU execution should be decomposed into subgraphs (4 at the moment):
# 1. the main model
# 2. selecting hidden states for each request
# 3. sampler
# 4. encoder.
# Each subgraph should be wrapped in torch.compile. This ensures that dummy_run
# and execute_model produce the same subgraph topology. The results from these
# subgraphs should either be passed to other subgraphs, or transferred from TPU
# to CPU using xla_tensor.cpu() for subsequent processing on the CPU.
#
# Step 3
# The dummy_run should be comprehensive: it must exercise every possible input
# shape and every branch that changes the subgraph inputs, so that all graphs
# are pre-compiled.
139
class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
140
141
142
143
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        original_parallel_config: ParallelConfig | None = None,
    ):
        """Cache configuration, derive padded static sizes and allocate the
        CPU-side staging buffers used by the TPU model runner.

        Args:
            vllm_config: Full engine configuration. Its sub-configs are cached
                as attributes on the runner.
            device: Target device; stored as ``self.device`` and handed to
                ``InputBatch``.
            original_parallel_config: Stored unchanged as
                ``self.original_parallel_config``. NOTE(review): its consumers
                are outside this chunk - confirm semantics there.

        The model itself is not created here; ``self.model`` is only declared
        and is set later (after ``load_model``).
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        # Opt-in tracking of XLA graph compilations; read by
        # _update_num_xla_graphs / _verify_num_xla_graphs.
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        # SPMD Related
        # With SPMD, build a ("x", "y") mesh of shape (num_devices, 1) over all
        # devices reported by the XLA runtime.
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        if self.use_spmd:
            num_devices = xr.global_runtime_device_count()
            mesh_shape = (num_devices, 1)
            device_ids = np.array(range(num_devices))
            self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))

        self.enforce_eager = model_config.enforce_eager

        # Must run after check_recompilation and enforce_eager are assigned:
        # _update_num_xla_graphs reads both.
        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # Resolve the KV-cache dtype. "auto" follows the model dtype (string
        # dtype names are mapped through TPU_STR_DTYPE_TO_TORCH_DTYPE). Any
        # other cache dtype string is mapped through the same table.
        if cache_config.cache_dtype == "auto":
            model_dtype = self.dtype
            if isinstance(model_dtype, str):
                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            else:
                self.kv_cache_dtype = model_dtype
        else:
            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
        self._hidden_states_dtype = self.dtype

        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        # Optional secondary length limit (VLLM_TPU_MOST_MODEL_LEN). When None,
        # only the max_model_len-based limits below are used.
        self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.num_blocks_per_most_len_req = (
            cdiv(self.most_model_len, self.block_size)
            if self.most_model_len is not None
            else None
        )
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
        # Padding buckets for the token dimension, computed by the module
        # helper _get_token_paddings (defined elsewhere in this file).
        self.num_tokens_paddings = _get_token_paddings(
            min_token_size=16,
            max_token_size=scheduler_config.max_num_batched_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP,
        )
        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
        # padded max value to pre-allocate data structures and pre-compile.
        self.max_num_tokens = self.num_tokens_paddings[-1]

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, "attention"
        )
        self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.vocab_size = model_config.get_vocab_size()

        # LoRA may add extra vocabulary entries; self.vocab_size (used e.g. to
        # size the grammar bitmask below) includes them.
        if self.lora_config is not None:
            self.vocab_size += self.lora_config.lora_extra_vocab_size

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        # Derived from the KV page size in bytes for one block; the mapping to
        # a slice count lives in _get_num_slices_per_kv_cache_update_block.
        self._num_slices_per_kv_cache_update_block = (
            _get_num_slices_per_kv_cache_update_block(
                get_page_size_bytes(
                    block_size=self.block_size,
                    num_kv_heads=self.num_kv_heads,
                    head_size=self.head_size,
                    kv_cache_dtype=self.kv_cache_dtype,
                )
            )
        )

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Initialize input batch early to avoid AttributeError in _update_states
        # NOTE(review): vocab_size here excludes the LoRA extra vocab that
        # self.vocab_size includes - confirm this is intended.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
            kernel_block_sizes=[self.cache_config.block_size],
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device="cpu"
        )

        self.positions_cpu = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device="cpu"
        )
        self.positions_np = self.positions_cpu.numpy()
        self.block_table_cpu = torch.zeros(
            (self.max_num_reqs, self.max_num_blocks_per_req),
            dtype=torch.int32,
            device="cpu",
        )
        # adjust num_reqs to avoid SMEM OOM.
        # Per-length batch-size caps from the Pallas backend, never above
        # max_num_reqs (None when most_model_len is unset).
        self.num_reqs_most_model_len = (
            min(
                PallasAttentionBackend.get_max_num_seqs(
                    self.most_model_len, self.block_size
                ),
                self.max_num_reqs,
            )
            if self.most_model_len is not None
            else None
        )
        self.num_reqs_max_model_len = min(
            PallasAttentionBackend.get_max_num_seqs(
                self.max_model_len, self.block_size
            ),
            self.max_num_reqs,
        )
        self.query_start_loc_cpu = torch.zeros(
            self.max_num_tokens + 1,
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Only relevant for multimodal models
        if self.supports_mm_inputs:
            self.is_mm_embed_cpu = torch.zeros(
                self.max_num_tokens,
                dtype=torch.bool,
                device="cpu",
                pin_memory=self.pin_memory,
            )

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64)
        # Request-count padding buckets between MIN_NUM_SEQS and max_num_reqs
        # (computed by the module helper _get_req_paddings).
        self.num_reqs_paddings = _get_req_paddings(
            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs
        )

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        # Populated by get_kv_cache_spec.
        self.shared_kv_cache_layers: dict[str, str] = {}

        # tensors for structured decoding
        # The bitmask has cdiv(vocab_size, 32) int32 columns per request,
        # presumably one bit per token id; structured_decode_arange holds the
        # 32 bit positions of a word.
        self.grammar_bitmask_cpu = torch.zeros(
            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.require_structured_out_cpu = torch.zeros(
            (self.max_num_reqs, 1),
            dtype=torch.bool,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.structured_decode_arange = torch.arange(
            0, 32, device="cpu", pin_memory=self.pin_memory
        )

        # Multimodal budget; None for models without multimodal inputs.
        self.mm_budget = (
            MultiModalBudget(
                self.model_config,
                self.scheduler_config,
                self.mm_registry,
            )
            if self.supports_mm_inputs
            else None
        )

        # Without SPMD, the sampler is wrapped in torch.compile (openxla
        # backend, static shapes); with SPMD the plain method is used.
        if not self.use_spmd:
            self.sample_from_logits_func = torch.compile(
                self.sample_from_logits,
                backend="openxla",
                fullgraph=True,
                dynamic=False,
            )
        else:
            self.sample_from_logits_func = self.sample_from_logits

        # For passing scheduler_output between successive
        # execute_model() and sample_tokens() calls.
        self.scheduler_output: SchedulerOutput | None = None
        self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None

380
381
382
383
    def reset_mm_cache(self) -> None:
        if self.mm_budget:
            self.mm_budget.reset_cache()

384
385
386
387
388
389
390
391
392
393
    def _update_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        total_cached_graphs = xr.get_num_cached_compilation_graph()
        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
        if new_compiled_graphs == 0:
            return

394
395
396
        logger.info(
            "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str
        )
397
398
399
400
401
402
403
404
405
406
407
        self.num_xla_graphs += new_compiled_graphs

    def _verify_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        assert self.num_xla_graphs == curr_cached_graph, (
            "Recompilation after warm up is detected during {}."
            " num_xla_graphs = {} curr_cached_graph = {}".format(
408
409
410
                case_str, self.num_xla_graphs, curr_cached_graph
            )
        )
411

412
413
414
415
416
417
418
419
    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        Returns:
420
            True if there is a new/resumed/paused/finished request.
421
422
423
424
425
426
427
428
429
430
431
432
            If False, we can skip copying SamplingMetadata to the GPU.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
433
        removed_req_indices: list[int] = []
434
435
436
437
438
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

439
        # Free the cached encoder outputs.
440
441
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)
442

443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

460
        req_ids_to_add: list[str] = []
461
462
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
463
            assert new_req_data.sampling_params is not None, (
464
                "Pooling is not supported in TPU yet"
465
            )
466
467
468
469
470
471
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
472
                prompt_embeds=new_req_data.prompt_embeds,
473
                mm_features=new_req_data.mm_features,
474
                sampling_params=sampling_params,
475
                pooling_params=None,
476
                generator=None,
477
478
479
480
481
482
483
484
485
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
486
487
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
488
            req_state = self.requests[req_id]
489
490
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
491
            resumed_from_preemption = req_id in req_data.resumed_req_ids
492
493

            # Update the cached states.
494
495
            req_state.num_computed_tokens = num_computed_tokens
            if not resumed_from_preemption:
496
497
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
498
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
499
                        block_ids.extend(new_ids)
500
            else:
501
                assert new_block_ids is not None
502
503
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
504
                req_state.block_ids = new_block_ids
505
506
507
508
509
510
511
512
513
514

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
515
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
516
            if new_block_ids is not None:
517
                self.input_batch.block_table.append_row(new_block_ids, req_index)
518
519
520
521
522
523

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
524
525
            # Fill the empty index or append to the end
            req_index = removed_req_indices.pop() if removed_req_indices else None
526
527
528
529
530
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)
531

532
533
534
535
536
        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

    def get_model(self) -> nn.Module:
        return self.model

537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """List the generation tasks the loaded model supports.

        A transcription-only model reports just ``["transcription"]``.
        Otherwise "generate" (for text-generation models) comes first,
        followed by "transcription" when the model supports it.
        """
        model = self.get_model()
        tasks: list[GenerationTask] = []

        if is_text_generation_model(model):
            tasks.append("generate")

        if not supports_transcription(model):
            return tasks

        if model.supports_transcription_only:
            return ["transcription"]

        tasks.append("transcription")
        return tasks

552
553
554
555
556
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """List the pooling tasks exposed by the model's pooler.

        Returns an empty list when the loaded model is not a pooling model.
        """
        model = self.get_model()
        if is_pooling_model(model):
            return list(model.pooler.get_supported_tasks())
        return []
558

559
560
561
562
563
564
565
566
567
568
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        tasks = list[SupportedTask]()

        if self.model_config.runner_type == "generate":
            tasks.extend(self.get_supported_generation_tasks())
        if self.model_config.runner_type == "pooling":
            tasks.extend(self.get_supported_pooling_tasks())

        return tuple(tasks)

569
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Side effect: every ``Attention`` layer that declares a
        ``kv_sharing_target_layer_name`` is recorded in
        ``self.shared_kv_cache_layers`` (layer -> target layer) and gets no
        spec of its own.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: for encoder-decoder attention layers.
            ValueError: for an unrecognised ``attn_type``.
        """
        layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
        block_size = self.vllm_config.cache_config.block_size
        cache_dtype_str = self.vllm_config.cache_config.cache_dtype

        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in layers.items():
            # Classic Attention path
            if isinstance(attn_module, Attention):
                if (
                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
                ) is not None:
                    # The layer doesn't need its own KV cache and will use that of
                    # the target layer. We skip creating a KVCacheSpec for it, so
                    # that KV cache management logic will act as this layer does
                    # not exist, and doesn't allocate KV cache for the layer. This
                    # enables the memory saving of cross-layer kv sharing, allowing
                    # a given amount of memory to accommodate longer context lengths
                    # or enable more requests to be processed simultaneously.
                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                    continue

                if attn_module.attn_type == AttentionType.DECODER:
                    if isinstance(attn_module, ChunkedLocalAttention):
                        logger.warning_once(
                            "Using irope in Pallas is not supported yet, it "
                            "will fall back to global attention for long context."
                        )
                    if attn_module.sliding_window is not None:
                        kv_cache_spec[layer_name] = SlidingWindowSpec(
                            block_size=block_size,
                            num_kv_heads=attn_module.num_kv_heads,
                            head_size=attn_module.head_size,
                            dtype=self.kv_cache_dtype,
                            sliding_window=attn_module.sliding_window,
                        )
                    else:
                        kv_cache_spec[layer_name] = FullAttentionSpec(
                            block_size=block_size,
                            num_kv_heads=attn_module.num_kv_heads,
                            head_size=attn_module.head_size,
                            dtype=self.kv_cache_dtype,
                        )
                elif attn_module.attn_type in (
                    AttentionType.ENCODER,
                    AttentionType.ENCODER_ONLY,
                ):
                    # encoder-only attention does not need KV cache.
                    continue
                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                    raise NotImplementedError(
                        "Encoder-decoder attention is not supported by the TPU "
                        f"model runner (layer: {layer_name})."
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
            # MLAAttention path
            elif isinstance(attn_module, MLAAttention):
                # Layer names are unique dict keys, so no duplicate check is
                # needed. MLA layers are registered with num_kv_heads=1.
                kv_cache_spec[layer_name] = MLAAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
                    cache_dtype_str=cache_dtype_str,
                )
            # Any other AttentionLayerBase implementation gets no KV cache spec.

        return kv_cache_spec

646
647
648
    def _get_slot_mapping_metadata(
        self, num_reqs, num_scheduled_tokens_per_req
    ) -> np.ndarray:
649
650
651
652
653
654
655
656
657
658
659
660
        """
        Computes metadata for mapping slots to blocks in the key-value (KV)
        cache for a batch of requests.

        This function determines, for each request in the batch, how the
        scheduled tokens are distributed across memory blocks, and generates
        metadata needed to map slices of tokens to their corresponding positions
        in the KV cache.

        Args:
            num_reqs (int): Number of requests in the current batch.
            num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
661
                to be scheduled for each request.
662
663
664

        Returns:
            np.ndarray: A 2D array of shape (total_block_len, 3), where each row
665
                contains:
666
                - kv_cache_start_index (int): The starting index in the KV cache
667
                  for the corresponding slice.
668
                - new_kv_start_index (int): The starting index in the new KV
669
                  cache for the corresponding slice.
670
671
672
                - slice_len (int): The length of the slice.
        """
        slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
673
674
675
676
        slices_end = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs]
            + num_scheduled_tokens_per_req
        )
677
678
679
680
        local_block_start_idx = slices_start // self.block_size
        local_block_end_idx = (slices_end - 1) // self.block_size
        no_repeat_req_indices = self.arange_np[:num_reqs]
        global_block_start_idx = (
681
682
            no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx
        )
683
684
685
686
687
688
689
        block_lens = local_block_end_idx - local_block_start_idx + 1
        global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
        slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
        global_block_indices = global_block_start_idx + slice_arange
        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
        total_block_len = np.sum(block_lens)
690
691
692
        slot_mapping_slices = np.repeat(
            np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0
        )
693
694
695
        cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
        np.cumsum(block_lens, out=cu_block_lens[1:])
        for req_idx in range(num_reqs):
696
697
698
699
700
701
            slot_mapping_slices[cu_block_lens[req_idx]][0] = (
                slices_start[req_idx] % self.block_size
            )
            slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = (
                slices_end[req_idx] - 1
            ) % self.block_size + 1
702
703
704
        slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
        cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
        np.cumsum(slice_lens, out=cu_slices_lens[1:])
705
706
707
        kv_cache_start_indices = slot_mapping_slices[:, 0] + (
            block_numbers * self.block_size
        )
708
709
        new_kv_start_indices = cu_slices_lens[:-1]
        slot_mapping_metadata = np.stack(
710
711
            [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1
        )
712
713
        return slot_mapping_metadata

714
    def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int):
        """Build the padded device inputs and attention metadata for one
        execution of the model.

        Scheduled-token counts are read for the requests starting at
        `start_index`. At most `num_reqs_most_model_len` of them are taken, or
        `num_reqs_max_model_len` when any of them schedules more than
        `most_model_len` tokens (or `most_model_len` is None). The caller
        resumes at the returned `end_index`.

        Side effects: fills the persistent CPU staging buffers
        (`positions_np`, `input_ids_cpu`, `query_start_loc_np`, `seq_lens_np`,
        `block_table_cpu`), sets `self.input_ids` / `self.position_ids` on
        `self.device`, and activates LoRA adapters when `lora_config` is set.

        Returns:
            A tuple `(per_layer_attn_metadata, logits_indices, padded_num_reqs,
            num_reqs, end_index)` where `per_layer_attn_metadata` maps every
            `Attention` layer name to the same `PallasMetadata`,
            `logits_indices` is a device tensor of `padded_num_reqs` sampling
            positions, and `num_reqs` is the number of requests handled here.
        """
        assert scheduler_output.total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0
        assert start_index < num_reqs

        # Get the number of scheduled tokens for each request. A single request
        # longer than most_model_len forces this execution onto the
        # max_model_len path.
        use_max_model_len = self.most_model_len is None
        num_scheduled_tokens_list: list[int] = []
        for i in range(start_index, num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if not use_max_model_len and num_tokens > self.most_model_len:
                use_max_model_len = True
            num_scheduled_tokens_list.append(num_tokens)

        # Cap the number of requests handled by this execution; the remaining
        # ones are picked up by the caller starting from `end_index`.
        max_reqs_this_execution = (
            self.num_reqs_max_model_len
            if use_max_model_len
            else self.num_reqs_most_model_len
        )
        num_scheduled_tokens_list = num_scheduled_tokens_list[
            :max_reqs_this_execution
        ]
        end_index = start_index + len(num_scheduled_tokens_list)

        max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_list)
        num_scheduled_tokens_per_req = np.array(
            num_scheduled_tokens_list, dtype=np.int32
        )
        total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
        assert max_num_scheduled_tokens_all_reqs > 0

        num_reqs = len(num_scheduled_tokens_per_req)

        # NOTE(review): the request-indexed arrays below (req_indices,
        # num_computed_tokens_cpu[:num_reqs], block_table[0][:num_reqs]) start
        # at request 0 rather than `start_index`; confirm this is intended when
        # the batch is split across several executions.

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, what are the corresponding req index.
        req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, what is its position in corresponding req.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req]
        )

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(
            self.input_batch.num_computed_tokens_cpu[req_indices],
            arange,
            out=positions_np,
        )

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (
            positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
        )

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(
            self.input_batch.token_ids_cpu_tensor.flatten(),
            0,
            torch.from_numpy(token_indices),
            out=self.input_ids_cpu[:total_num_scheduled_tokens],
        )

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(
            num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1]
        )
        self.query_start_loc_np[num_reqs + 1 :] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs]
            + num_scheduled_tokens_per_req
        )

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = _get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens
        )
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[
            total_num_scheduled_tokens:padded_total_num_scheduled_tokens
        ] = 0
        self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(
            self.device
        )
        self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(
            self.device
        )
        if use_max_model_len:
            block_tables = self.block_table_cpu[
                : self.num_reqs_max_model_len, : self.max_num_blocks_per_req
            ]
            block_tables[:num_reqs, : self.max_num_blocks_per_req] = (
                self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]
            )
            query_start_loc = self.query_start_loc_cpu[
                : self.num_reqs_max_model_len + 1
            ].to(self.device)
            seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
        else:
            block_tables = self.block_table_cpu[
                : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
            ]
            block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = (
                self.input_batch.block_table[0].get_cpu_tensor()[
                    :num_reqs, : self.num_blocks_per_most_len_req
                ]
            )
            query_start_loc = self.query_start_loc_cpu[
                : self.num_reqs_most_model_len + 1
            ].to(self.device)
            seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device)
        block_tables = block_tables.to(self.device)

        # Calculate the slot mapping
        slot_mapping_metadata = self._get_slot_mapping_metadata(
            num_reqs, num_scheduled_tokens_per_req
        )
        num_kv_update_slices = slot_mapping_metadata.shape[0]
        padded_num_slices = _get_padded_num_kv_cache_update_slices(
            padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size
        )
        slot_mapping_metadata = np.pad(
            slot_mapping_metadata,
            [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
            constant_values=0,
        )
        slot_mapping_metadata = np.transpose(slot_mapping_metadata)
        slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device)

        # Activate LoRA adapters once per execution. The padding tokens are
        # attributed to the last request so the mapping covers the padded length.
        if self.lora_config is not None:
            padded_num_scheduled_tokens_per_req = np.copy(
                num_scheduled_tokens_per_req
            )  # Copying to avoid accidental state corruption bugs
            padded_num_scheduled_tokens_per_req[-1] += (
                padded_total_num_scheduled_tokens - total_num_scheduled_tokens
            )
            self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping_metadata,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device),
            num_kv_update_slices=torch.tensor(
                [num_kv_update_slices], dtype=torch.int32, device=self.device
            ),
            num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs
        )
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)

        # Every attention layer shares the same metadata object.
        layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
        per_layer_attn_metadata = {
            layer_name: attn_metadata for layer_name in layer_names
        }
        return (
            per_layer_attn_metadata,
            logits_indices,
            padded_num_reqs,
            num_reqs,
            end_index,
        )

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        Every resulting output is stored in `self.encoder_cache`, keyed by the
        identifier of the multimodal feature it came from. Returns immediately
        when the scheduler scheduled no encoder inputs.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Flatten the scheduled items, in scheduling order, into the kwargs to
        # batch plus the (mm_hash, placeholder) pairs used to key the results.
        mm_kwargs: list[MultiModalKwargsItem] = []
        mm_hashes_pos: list[tuple[str, PlaceholderRange]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            features = self.requests[req_id].mm_features
            for feature in (features[idx] for idx in encoder_input_ids):
                mm_kwargs.append(feature.data)
                mm_hashes_pos.append((feature.identifier, feature.mm_position))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        model = cast(SupportsMultiModal, self.model)
        encoder_outputs = []
        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
        ):
            # The encoder call is bracketed by non-blocking XLA syncs.
            # Its result is either a tensor of shape
            # (num_items, feature_size, hidden_size) when the feature size is
            # fixed, or a list/tuple of num_items tensors of shape
            # (feature_size, hidden_size) when it depends on the item.
            torch_xla.sync(wait=False)
            curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group)
            torch_xla.sync(wait=False)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs, expected_num_items=num_items
            )

            if isinstance(curr_group_outputs, torch.Tensor):
                # NOTE(review): a stacked tensor is stored as ONE entry, while
                # `mm_hashes_pos` holds one entry per item; with num_items > 1
                # the zip below would pair outputs with the wrong items.
                # Confirm TPU encoders only return stacked tensors for
                # single-item groups.
                encoder_outputs.append(curr_group_outputs)
            else:
                assert isinstance(curr_group_outputs, (list, tuple))
                encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs.
        # NOTE (NickLucche) here we diverge from logic in other runners, as we
        # assume to only have whole mm items to process. Hence we avoid the
        # intrinsic dynamism that `scatter_mm_placeholders` introduces.
        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
            assert pos_info.is_embed is None, (
                "Expected all positions to be contiguous and embeddings."
            )
            self.encoder_cache[mm_hash] = output

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Collect cached encoder outputs needed by this step's tokens.

        For each request (in `input_batch.req_ids` order), every multimodal
        item whose placeholder range overlaps the request's scheduled window
        `[num_computed_tokens, num_computed_tokens + num_scheduled_tokens)`
        contributes its whole cached encoder output, and the overlapping token
        positions are flagged in a boolean mask over the flattened batch.

        Returns:
            `(mm_embeds, is_mm_embed)`: the gathered encoder outputs and the
            mask, sliced to the padded token count and moved to `self.device`.
        """
        num_tokens = scheduler_output.total_num_scheduled_tokens
        padded_num_tokens = _get_padded_token_len(self.num_tokens_paddings, num_tokens)

        # Reuse the persistent CPU mask; only the prefix that is shipped to the
        # device needs clearing.
        mask_cpu = self.is_mm_embed_cpu
        mask_cpu[:padded_num_tokens] = False
        gathered: list[torch.Tensor] = []

        offset_in_batch = 0
        for req_id in self.input_batch.req_ids:
            n_sched = scheduler_output.num_scheduled_tokens[req_id]
            state = self.requests[req_id]
            n_computed = state.num_computed_tokens
            window_end = n_computed + n_sched

            # TODO unroll loop and assume/enforce --disable_chunked_mm_input
            # NOTE (NickLucche) here we diverge from logic in other runners, as
            # we assume to only have whole mm items to process. Hence we avoid
            # the intrinsic dynamism that `gather_mm_placeholders` introduces.
            for feature in state.mm_features:
                placeholder = feature.mm_position
                item_start = placeholder.offset
                item_len = placeholder.length

                if item_start >= window_end:
                    # Items from here on start after this step's window.
                    break
                if item_start + item_len <= n_computed:
                    # Already consumed and stored in the decoder's KV cache.
                    continue

                # Overlap of the item with the window, relative to item_start.
                lo = max(n_computed - item_start, 0)
                hi = min(window_end - item_start, item_len)
                assert lo < hi

                cached = self.encoder_cache.get(feature.identifier)
                assert cached is not None, (
                    f"Encoder cache miss for {feature.identifier}."
                )

                assert placeholder.is_embed is None, (
                    "Expected all positions to be contiguous and embeddings."
                )

                base = offset_in_batch + item_start - n_computed
                mask_cpu[base + lo : base + hi] = True

                # Only whole mm items are processed
                gathered.append(cached)

            offset_in_batch += n_sched

        return gathered, mask_cpu[:padded_num_tokens].to(self.device)

    def _get_model_inputs(
        self,
        input_ids: torch.Tensor,
1059
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
1060
    ):
1061
        if self.supports_mm_inputs:
1062
1063
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

1064
1065
1066
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
1067
            inputs_embeds = self.model.get_input_embeddings(
1068
                input_ids,
1069
                multimodal_embeddings=mm_embeds,
1070
                is_multimodal=is_mm_embed,
1071
            )
1072

1073
1074
1075
1076
1077
1078
1079
1080
            return None, inputs_embeds
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            return input_ids, None

1081
1082
1083
1084
    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: IntermediateTensors | None = None,
    ) -> ModelRunnerOutput | None:
        """First half of a step: update state and run the multimodal encoder.

        The forward pass and sampling are deferred to `sample_tokens()`, which
        consumes the `scheduler_output` / `mm_embed_inputs` stashed here.
        `intermediate_tensors` is not used by this method.

        Returns:
            None when work was stashed for `sample_tokens()`. When nothing is
            scheduled, returns `EMPTY_MODEL_RUNNER_OUTPUT`, or the result of
            `kv_connector_no_forward()` if a KV transfer group exists.

        Raises:
            RuntimeError: If the previous step's stashed output has not been
                consumed by `sample_tokens()` yet.
        """
        if self.scheduler_output is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        # Update cached state
        self._update_states(scheduler_output)

        if not scheduler_output.total_num_scheduled_tokens:
            # No forward pass this step; only KV transfers may need servicing.
            if has_kv_transfer_group():
                return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
            return EMPTY_MODEL_RUNNER_OUTPUT

        mm_embed_inputs = None
        if self.supports_mm_inputs:
            # Run the multimodal encoder if any, then collect its outputs.
            self._execute_mm_encoder(scheduler_output)
            mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)

        torch_xla.sync(wait=False)

        # Hand the step over to sample_tokens().
        self.scheduler_output = scheduler_output
        self.mm_embed_inputs = mm_embed_inputs
        return None

    @torch.no_grad()
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | None:
        """Second half of a step: run the decoder and sample next tokens.

        Consumes the state stashed by `execute_model()`. The scheduled requests
        may be split into several executions (each `_prepare_inputs` call
        returns where the next one starts); sampled tokens and logprobs of all
        executions are concatenated before updating per-request CPU state.

        Args:
            grammar_output: Structured-decoding input for this step, or None.

        Returns:
            The step's `ModelRunnerOutput`, or None when `execute_model()`
            stashed nothing (e.g. PP non-final rank; the output is unused).
        """
        if self.scheduler_output is None:
            # Nothing to do (PP non-final rank case), output isn't used.
            return None
        scheduler_output = self.scheduler_output
        mm_embed_inputs = self.mm_embed_inputs
        self.scheduler_output = None
        self.mm_embed_inputs = None

        # Prepare inputs, the requests might be split into multiple
        # executions, combine the result of each execution.
        start_index = 0
        combined_selected_tokens: list[torch.Tensor] = []
        combined_logprobs: list[LogprobsLists] = []

        # NOTE: setup current batch's metadata for kv connector.
        # Currently, only verified with NixlConnector
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)

        while start_index < self.input_batch.num_reqs:
            attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = (
                self._prepare_inputs(scheduler_output, start_index)
            )
            input_ids, inputs_embeds = self._get_model_inputs(
                self.input_ids, mm_embed_inputs
            )
            torch_xla.sync(wait=False)
            # Run the decoder
            with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=scheduler_output.total_num_scheduled_tokens,
            ):
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self.position_ids,
                    inputs_embeds=inputs_embeds,
                )
            hidden_states = self.select_hidden_states(hidden_states, logits_indices)
            logits = self.compute_logits(hidden_states)
            tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
                self.input_batch, padded_num_reqs, self.device
            )
            if grammar_output is not None:
                require_struct_decoding, grammar_bitmask_padded, arange = (
                    self.prepare_structured_decoding_input(logits, grammar_output)
                )
                logits = self.structured_decode(
                    require_struct_decoding, grammar_bitmask_padded, logits, arange
                )
            selected_token_ids = self.sample_from_logits_func(
                logits, tpu_sampling_metadata
            )
            # NOTE (NickLucche) Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs. We can't enforce it
            # due to recompilations outside torch.compiled code, so just make
            # sure `sample_from_logits` does not modify the logits in-place.
            logprobs = (
                self.gather_logprobs(logits, selected_token_ids)
                if tpu_sampling_metadata.logprobs
                else None
            )

            # Remove padding on cpu and keep dynamic op outside of xla graph.
            selected_token_ids = selected_token_ids.cpu()[:num_reqs]

            combined_selected_tokens.append(selected_token_ids)
            if tpu_sampling_metadata.logprobs:
                combined_logprobs.append(logprobs.tolists())

            start_index = end_index

        # NOTE: current kv load and save get h2d/d2h copies involved.
        # Those copies are blocking. Once they become async., kv_save
        # should be called right after each single forward pass,
        # instead of the forwards of the entire input batch.
        self.maybe_wait_for_kv_save()
        finished_sending, finished_recving = self.get_finished_kv_transfers(
            scheduler_output
        )

        selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
        # `tpu_sampling_metadata` here is the one built by the LAST execution
        # of the loop above (there is always at least one, since
        # execute_model() only stashes steps with scheduled tokens).
        if tpu_sampling_metadata.logprobs:
            logprobs_lists = LogprobsLists(
                logprob_token_ids=[
                    x for lp in combined_logprobs for x in lp.logprob_token_ids
                ],
                logprobs=[x for lp in combined_logprobs for x in lp.logprobs],
                sampled_token_ranks=[
                    x for lp in combined_logprobs for x in lp.sampled_token_ranks
                ],
            )
        else:
            logprobs_lists = None

        # Update the per-request CPU state with the sampled tokens.
        # `selected_token_ids` already lives on the CPU at this point.
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices = []
        num_reqs = self.input_batch.num_reqs
        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (
                req_state.num_computed_tokens
                + scheduler_output.num_scheduled_tokens[req_id]
            )
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        assert all(
            req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]
        ), "req_ids contains None"
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        # Prompt logprobs are not supported on TPU yet.
        prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {
            req_id: None for req_id in req_ids
        }

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            #       the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1

        else:
            # NOTE(review): unlike the branch above, tokens of partial requests
            # (discard_sampled_tokens_req_indices) are not cleared here and
            # num_tokens is bumped for every request; see the TODO above.
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[i, target_slice] = (
                    valid_sampled_token_ids[i]
                )
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        kv_connector_output = (
            None
            if (finished_sending is None and finished_recving is None)
            else KVConnectorOutput(
                finished_sending=finished_sending,
                finished_recving=finished_recving,
            )
        )

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            kv_connector_output=kv_connector_output,
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

        return model_runner_output

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Apply overrides to a subset of this runner's config attributes.

        Args:
            overrides: Maps a config attribute name to the overrides handed to
                the `update_config` helper imported from `vllm.config`. Only
                "load_config" and "model_config" are accepted.

        Raises:
            AssertionError: If a name outside the allowed set is given.
        """
        # TODO: TPU config may need extra validation
        # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
        allowed_config_names = {"load_config", "model_config"}
        for name, new_values in overrides.items():
            assert name in allowed_config_names, (
                f"Config `{name}` not supported. "
                f"Allowed configs: {allowed_config_names}"
            )
            setattr(self, name, update_config(getattr(self, name), new_values))

    def load_model(self) -> None:
        """Load the model, optionally wrap it for LoRA, and create the sampler.

        With `self.use_spmd` the weights are loaded by `TPUModelLoader` onto
        `self.mesh`; otherwise the generic loader from `get_model_loader` is
        used. If `self.model` is already set, it is kept and the freshly loaded
        model is not assigned.

        Raises:
            RuntimeError: Wrapping (and chained from) any RuntimeError raised
                while loading, with a hint about device HBM capacity.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
            "vllm.model_executor.layers.vocab_parallel_embedding."
            "get_tensor_model_parallel_rank",
            return_value=xm_tp_rank,
        ):
            try:
                if not self.use_spmd:
                    loader = get_model_loader(self.load_config)
                    logger.info("Loading model from scratch...")
                    model = loader.load_model(
                        vllm_config=self.vllm_config, model_config=self.model_config
                    )
                else:
                    spmd_loader = TPUModelLoader(
                        load_config=self.vllm_config.load_config
                    )
                    model = spmd_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.vllm_config.model_config,
                        mesh=self.mesh,
                    )
            except RuntimeError as err:
                raise RuntimeError(
                    "Unable to load model, a likely reason is the model is "
                    "too large for the current device's HBM memory. "
                    "Consider switching to a smaller model "
                    "or sharding the weights on more chips. "
                    f"See the detailed error: {err}"
                ) from err

        if self.lora_config is not None:
            model = self.load_lora_model(model, self.vllm_config, self.device)
            replace_set_lora(model)

        # Sync all pending XLA execution during model initialization and weight
        # loading.
        torch_xla.sync(wait=False)
        xm.wait_device_ops()

        if not hasattr(self, "model"):
            self.model = model
        self.sampler = TPUSampler()

    def reload_weights(self) -> None:
        """Reload the checkpoint weights into the already-loaded model in place.

        Raises:
            AssertionError: if ``self.model`` is missing or ``None``.
        """
        current_model = getattr(self, "model", None)
        assert current_model is not None, (
            "Cannot reload weights before model is loaded."
        )
        weight_loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        weight_loader.load_weights(current_model, model_config=self.model_config)

1385
    @torch.no_grad()
    def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None:
        """Run one forward pass on zero-filled inputs to trigger compilation.

        Args:
            num_tokens: Number of (padded) tokens in the dummy batch.
            num_reqs: Number of rows in the dummy block table / context lens.
            num_blocks: Number of columns in the dummy block table.

        Side effect: stores the dtype of the model output in
        ``self._hidden_states_dtype``.
        """
        int32 = torch.int32

        # Multimodal-capable models are fed `inputs_embeds`; text-only models
        # are fed `input_ids`.
        if not self.supports_mm_inputs:
            input_ids = torch.zeros(num_tokens, dtype=int32).to(self.device)
            inputs_embeds = None
        else:
            input_ids = None
            inputs_embeds = torch.zeros(
                (num_tokens, self.hidden_size), dtype=self.dtype, device=self.device
            )

        actual_num_reqs = min(num_tokens, num_reqs)
        position_ids = torch.zeros(num_tokens, dtype=int32).to(self.device)
        padded_num_slices = _get_padded_num_kv_cache_update_slices(
            num_tokens, self.max_num_reqs, self.block_size
        )
        num_kv_update_slices = torch.tensor(
            [padded_num_slices], dtype=int32
        ).to(self.device)
        slot_mapping = torch.zeros((3, padded_num_slices), dtype=int32).to(
            self.device
        )
        block_tables = torch.zeros((num_reqs, num_blocks), dtype=int32).to(
            self.device
        )
        # Every dummy request contributes exactly one query token.
        query_start_loc = torch.cumsum(
            torch.tensor([0] + [1] * num_reqs, dtype=int32), dim=0, dtype=int32
        ).to(self.device)
        context_lens = torch.ones((num_reqs,), dtype=int32).to(self.device)
        num_seqs = torch.tensor([actual_num_reqs], dtype=int32).to(self.device)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
            num_kv_update_slices=num_kv_update_slices,
            num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
        )

        # (tensor, dims) pairs that dynamo should treat as dynamic.
        dynamic_dims = (
            (inputs_embeds if self.supports_mm_inputs else input_ids, 0),
            (position_ids, 0),
            (attn_metadata.slot_mapping, 0),
            (attn_metadata.block_tables, (0, 1)),
            (attn_metadata.context_lens, 0),
            (attn_metadata.query_start_loc, 0),
        )
        for tensor, dims in dynamic_dims:
            torch._dynamo.mark_dynamic(tensor, dims)

        # Every attention layer shares the same dummy metadata object.
        attention_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        per_layer_attn_metadata = dict.fromkeys(attention_layers.keys(), attn_metadata)

        dummy_lora_tokens = np.array([num_tokens], dtype=np.int32)
        with (
            self.maybe_select_dummy_loras(self.lora_config, dummy_lora_tokens),
            set_forward_context(per_layer_attn_metadata, self.vllm_config, 0),
        ):
            hidden_states = self.model(
                input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds
            )
        self._hidden_states_dtype = hidden_states.dtype
1450

1451
1452
1453
    def _set_active_loras(
        self, prompt_lora_mapping, token_lora_mapping, lora_requests
    ) -> None:
        """Delegate to the base implementation, fenced by two XLA syncs.

        The first sync captures pending input updates. The second captures the
        LoRA metadata updates made by the base ``_set_active_loras``.
        """
        torch_xla.sync(wait=False)
        super()._set_active_loras(
            prompt_lora_mapping,
            token_lora_mapping,
            lora_requests,
        )
        torch_xla.sync(wait=False)
1459

1460
    def _precompile_mm_encoder(self) -> None:
1461
        if not self.supports_mm_inputs:
1462
1463
            return

1464
1465
        # Pre-compile MM encoder for all supported data modalities.
        hf_config = self.vllm_config.model_config.hf_config
1466
1467
1468
1469
1470
1471
1472

        mm_budget = self.mm_budget
        assert mm_budget is not None

        max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality  # noqa: E501

        for mode, max_items_per_seq in max_items_per_seq_by_modality.items():
1473
            logger.info(
1474
1475
                "Compiling Multimodal %s Encoder with different input shapes.", mode
            )
1476
1477
            start = time.perf_counter()
            # No padding for MM encoder just yet.
1478
            for num_items in range(1, max_items_per_seq + 1):
1479
1480
                logger.info("  -- mode: %s items: %d", mode, num_items)
                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
1481
1482
1483
                    mode,
                    num_items,
                )
1484
                # Run multimodal encoder.
1485
                torch_xla.sync(wait=False)
1486
                mm_embeds = self.model.get_multimodal_embeddings(
1487
1488
                    **batched_dummy_mm_inputs
                )
1489
                torch_xla.sync(wait=False)
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
                num_patches = mm_embeds[0].shape[0]
                items_size = num_patches * num_items

                # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm
                # embeddings are present. We assume `--disable-mm-chunked`,
                # hence only whole items can be scheduled. This implies we just
                # need to compile when `num_items` fit the (padded) `input_ids`
                for num_tokens in self.num_tokens_paddings:
                    if num_tokens >= items_size:
                        # XLA Workaround: if torch.zeros(..device) is used, XLA
                        # compiles a scalar+expansion op, which won't match
                        # the graph generated at runtime. CPU->TPU must be used
1502
1503
1504
                        placeholders_ids = torch.zeros(
                            num_tokens, dtype=torch.int32, device="cpu"
                        )
1505
                        # Align placeholders and actual num mm_embeddings.
1506
                        placeholders_ids[:items_size] = hf_config.image_token_index
1507
1508

                        placeholders_ids = placeholders_ids.to(self.device)
1509
1510
1511
1512

                        mm_mask = torch.tensor([False] * num_tokens)
                        mm_mask[:items_size] = True
                        mm_mask = mm_mask.to(self.device)
1513
                        # Assign outputs or the graph will be cut short.
1514
1515
1516
1517
                        a, b = self._get_model_inputs(
                            placeholders_ids,
                            mm_embed_inputs=([mm_embeds], mm_mask),
                        )
1518
                        assert a is None
1519
                        torch_xla.sync(wait=False)
1520
1521
1522
1523

            # Pre-compile `get_input_embeddings` when mm_embeddings are not
            # present. Chunk is only made of text, no mm_placeholders.
            for num_tokens in self.num_tokens_paddings:
1524
1525
1526
                placeholders_ids = torch.zeros(
                    num_tokens, dtype=torch.int32, device="cpu"
                )
1527
                placeholders_ids = placeholders_ids.to(self.device)
1528
1529
1530
1531
                a, b = self._get_model_inputs(
                    placeholders_ids,
                    mm_embed_inputs=None,
                )
1532
                assert a is None
1533
                torch_xla.sync(wait=False)
1534
1535
1536
1537

            xm.wait_device_ops()
            end = time.perf_counter()
            logger.info(
1538
1539
1540
1541
                "Multimodal %s Encoder compilation finished in in %.2f [secs].",
                mode,
                end - start,
            )
1542

1543
    def _precompile_backbone(self) -> None:
        """Compile the model forward pass for every padded token count.

        Each token count is compiled with the max-model-len request layout and,
        when ``self.most_model_len`` is set, also with the most-model-len
        layout.
        """
        logger.info("Compiling the model with different input shapes.")
        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(
                num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
            )
            if self.most_model_len is None:
                continue
            self._dummy_run(
                num_tokens,
                self.num_reqs_most_model_len,
                self.num_blocks_per_most_len_req,
            )
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("model backbone")
1561

1562
1563
1564
    def _precompile_select_hidden_states(self) -> None:
        """Compile ``select_hidden_states`` over token x request buckets.

        For every padded token count, request buckets are compiled in order up
        to and including the first one that reaches
        ``min(num_tokens, self.max_num_reqs)``. The graph is tiny, so the
        cross product is cheap.
        """
        logger.info("Compiling select_hidden_states with different input shapes.")
        start = time.perf_counter()
        hidden_size = self.model_config.get_hidden_size()
        for num_tokens in self.num_tokens_paddings:
            hidden = torch.zeros(
                (num_tokens, hidden_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            torch._dynamo.mark_dynamic(hidden, 0)
            for num_reqs in self.num_reqs_paddings:
                sample_indices = torch.zeros(
                    num_reqs, dtype=torch.int32, device=self.device
                )
                torch._dynamo.mark_dynamic(sample_indices, 0)
                self.select_hidden_states(hidden, sample_indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs)
                # There cannot be more requests than tokens, but the next
                # larger bucket is still compiled in case num_tokens is padded.
                if num_reqs >= min(num_tokens, self.max_num_reqs):
                    break
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("select_hidden_states")
1586

1587
1588
    def _precompile_compute_logits(self) -> None:
        """Compile ``compute_logits`` once per padded request count."""
        logger.info("Compiling compute_logits with different input shapes.")
        start = time.perf_counter()
        hidden_size = self.model_config.get_hidden_size()
        for num_reqs in self.num_reqs_paddings:
            hidden = torch.zeros(
                (num_reqs, hidden_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            torch._dynamo.mark_dynamic(hidden, 0)
            self.compute_logits(hidden)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("compute_logits")

    def _precompile_structured_decoding(self) -> None:
        """Compile ``structured_decode`` once per padded request count."""
        logger.info("Compiling structured_decoding with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            logits = torch.zeros(
                (num_reqs, self.vocab_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            require_mask = self.require_structured_out_cpu[:num_reqs].to(self.device)
            bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device)
            # The leading dim of logits / require_mask / bitmask is NOT marked
            # dynamic: some ops in structured_decode need it to be static.
            bit_positions = self.structured_decode_arange.to(self.device)
            self.structured_decode(require_mask, bitmask, logits, bit_positions)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("structured_decoding")

    def _precompile_sample_from_logits(self) -> None:
        """Compile the sampler per padded request count, greedy and non-greedy."""
        logger.info("Compiling sample_from_logits with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            logits = torch.zeros(
                (num_reqs, self.vocab_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            # The leading dim of `logits` is NOT marked dynamic: some sampler
            # ops need it to be static.
            for all_greedy in (False, True):
                # Sampling params are only generated for the non-greedy graph.
                metadata = TPUSupportedSamplingMetadata.from_input_batch(
                    self.input_batch,
                    num_reqs,
                    self.device,
                    not all_greedy,
                )
                metadata.all_greedy = all_greedy
                dummy_lora_batch = np.array([num_reqs], dtype=np.int32)
                with self.maybe_select_dummy_loras(self.lora_config, dummy_lora_batch):
                    self.sample_from_logits_func(logits, metadata)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("sample_from_logits")
1661

1662
1663
1664
1665
    def _precompile_gather_logprobs(self) -> None:
        """Compile ``gather_logprobs`` once per padded request count."""
        logger.info("Compiling gather_logprobs with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            logits = torch.zeros(
                (num_reqs, self.vocab_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            sampled_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(
                self.device
            )
            dummy_lora_batch = np.array([num_reqs], dtype=np.int32)
            with self.maybe_select_dummy_loras(self.lora_config, dummy_lora_batch):
                self.gather_logprobs(logits, sampled_tokens)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("gather_logprobs")

1682
1683
1684
1685
    def capture_model(self) -> None:
        """Precompile every subgraph for all of its padded input shapes.

        All stages run, in a fixed order, inside the
        ``maybe_setup_dummy_loras(self.lora_config)`` context.
        """
        with self.maybe_setup_dummy_loras(self.lora_config):
            stages = (
                self._precompile_mm_encoder,
                self._precompile_backbone,
                self._precompile_select_hidden_states,
                self._precompile_compute_logits,
                self._precompile_structured_decoding,
                self._precompile_sample_from_logits,
                self._precompile_gather_logprobs,
            )
            for stage in stages:
                stage()
1694

1695
1696
1697
1698
1699
    def profile_run(
        self,
        num_tokens: int,
    ) -> None:
        """Profile memory by running the encoder (if any) and the backbone.

        For multimodal models the encoder is first run on a dummy batch and its
        outputs are cached. The backbone is then run at ``num_tokens`` for the
        max-model-len layout and, if set, the most-model-len layout. The encoder
        cache is cleared afterwards.

        Args:
            num_tokens: Number of tokens for the backbone dummy run.
        """
        # Profile with multimodal encoder & encoder cache.
        if self.supports_mm_inputs:
            self._profile_multimodal_encoder()

        # Trigger compilation for general shape.
        self._dummy_run(
            num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
        )
        if self.most_model_len is not None:
            self._dummy_run(
                num_tokens,
                self.num_reqs_most_model_len,
                self.num_blocks_per_most_len_req,
            )

        torch_xla.sync(wait=False)
        xm.wait_device_ops()
        self.encoder_cache.clear()
        gc.collect()

    def _profile_multimodal_encoder(self) -> None:
        """Run the MM encoder once on a maximal dummy batch and cache the output.

        Skipped when ``skip_mm_profiling`` is set or the encoder budget is not
        positive. Outputs are stored under ``self.encoder_cache["tmp"]``.
        """
        if self.model_config.multimodal_config.skip_mm_profiling:
            logger.info(
                "Skipping memory profiling for multimodal encoder and "
                "encoder cache."
            )
            return

        mm_budget = self.mm_budget
        assert mm_budget is not None

        # TODO: handle encoder-decoder models once we support them.
        encoder_budget = mm_budget.get_encoder_budget()
        if encoder_budget <= 0:
            return

        # NOTE: Currently model is profiled with a single non-text modality
        # with the max possible input tokens even when it supports multiple.
        modality = mm_budget.get_modality_with_max_tokens()
        num_items = mm_budget.max_items_per_batch_by_modality[modality]
        logger.info(
            "Encoder cache will be initialized with a budget of "
            "%s tokens, and profiled with %s %s items of the "
            "maximum feature size.",
            encoder_budget,
            num_items,
            modality,
        )

        dummy_inputs = self._get_mm_dummy_batch(modality, num_items)

        # Isolate the encoder graph from post-processing with syncs on both
        # sides, to limit the impact of recompilation until it's fixed.
        start = time.perf_counter()
        torch_xla.sync(wait=False)
        encoder_outputs = self.model.get_multimodal_embeddings(**dummy_inputs)
        torch_xla.sync(wait=False)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - start
        logger.info(
            "Multimodal Encoder profiling finished in %.2f [secs].",
            elapsed,
        )

        sanity_check_mm_encoder_outputs(
            encoder_outputs,
            expected_num_items=num_items,
        )
        # Cache the dummy encoder outputs.
        self.encoder_cache["tmp"] = dict(enumerate(encoder_outputs))

1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
    def maybe_setup_cross_layer_kv_sharing(
        self,
        kv_caches: dict[str, torch.Tensor],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """Wire up layers that reuse another layer's KV cache.

        Each sharing layer is added to the KV cache group of its target layer,
        and its ``kv_caches`` entry is pointed at the target's tensor. Does
        nothing when ``self.shared_kv_cache_layers`` is empty.
        """
        sharing = self.shared_kv_cache_layers
        if not sharing:
            return

        add_kv_sharing_layers_to_kv_cache_groups(
            sharing,
            kv_cache_config.kv_cache_groups,
        )

        for sharer, owner in sharing.items():
            logger.debug("%s reuses KV cache of %s", sharer, owner)
            kv_caches[sharer] = kv_caches[owner]

1797
1798
1799
1800
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate, bind and (under SPMD) shard the TPU KV caches.

        Args:
            kv_cache_config: KV cache configuration, including the size of the
                tensor backing each layer's cache.

        Raises:
            NotImplementedError: for more than one KV cache group, or for a
                cache spec that is not an ``AttentionSpec``.
        """
        groups = kv_cache_config.kv_cache_groups
        if len(groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not supported yet."
            )

        # Rebuild the input batch if the KV cache block size differs from the
        # block size the runner was configured with.
        kv_block_size = groups[0].kv_cache_spec.block_size
        if kv_block_size != self.block_size:
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=[kv_block_size],
                kernel_block_sizes=[kv_block_size],
            )
        # block_table_cpu and the input batch's block table must share a dtype.
        assert (
            self.block_table_cpu.dtype
            == self.input_batch.block_table[0].get_cpu_tensor().dtype
        )

        # Byte size of the backing tensor of each layer (no sharing allowed).
        size_by_layer = {}
        for backing in kv_cache_config.kv_cache_tensors:
            assert len(backing.shared_by) == 1, (
                "KV cache tensor shared by multiple layers is not supported in TPU."
            )
            size_by_layer[backing.shared_by[0]] = backing.size

        kv_caches: dict[str, torch.Tensor] = {}
        for group in groups:
            spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                nbytes = size_by_layer[layer_name]
                assert nbytes % spec.page_size_bytes == 0
                num_blocks = nbytes // spec.page_size_bytes
                if not isinstance(spec, AttentionSpec):
                    raise NotImplementedError
                if self.use_spmd:
                    num_kv_heads = spec.num_kv_heads
                    assert self.original_parallel_config is not None
                    tp_size = self.original_parallel_config.tensor_parallel_size
                    # TODO: Handle kv cache duplication under SPMD mode.
                    assert num_kv_heads % tp_size == 0, (
                        f"num_kv_heads {num_kv_heads} must be divisible by "
                        f"tp_size {tp_size} under SPMD mode"
                    )
                cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                    num_blocks,
                    spec.block_size,
                    spec.num_kv_heads,
                    spec.head_size,
                )
                kv_caches[layer_name] = torch.zeros(
                    cache_shape, dtype=spec.dtype
                ).to(self.device)

        # Set up cross-layer KV cache sharing if needed
        self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches,
        )

        if self.use_spmd:
            # Shard dim 1 of every cache over the "x" mesh axis.
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, "x", None, None))

        if has_kv_transfer_group():
            transfer_group = get_kv_transfer_group()
            transfer_group.register_kv_caches(kv_caches)
            transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

1891
    def reset_dynamo_cache(self):
        """Drop dynamo's cached bytecode for the compiled language backbone.

        Only acts when the backbone is wrapped in a
        ``TorchCompileWrapperWithCustomDispatcher``.
        """
        # NOTE: `is_multimodal_model` (not `supports_mm_inputs`) is checked
        # because for a multimodal model the compiled language backbone must be
        # reached through `get_language_model`.
        backbone_owner = (
            self.model.get_language_model()
            if self.model_config.is_multimodal_model
            else self.model
        )
        compiled_model = backbone_owner.model
        if not isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
            return
        logger.info("Clear dynamo cache and cached dynamo bytecode.")
        torch._dynamo.eval_frame.remove_from_cache(compiled_model.original_code_object)
        compiled_model.compiled_codes.clear()
1905

1906
1907
1908
1909
1910
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def select_hidden_states(self, hidden_states, indices_do_sample):
        return hidden_states[indices_do_sample]

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
1911
    def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor:
1912
        return self.model.compute_logits(sample_hidden_states)
1913

1914
1915
1916
    # TODO: sample_from_logits produces incorrect results under SPMD mode.
    #       Restore the decorator below once the issue is fixed in torchxla.
    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_logits(
        self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata
    ) -> torch.Tensor:
        """Sample one token per row using an XLA-friendly code path.

        Kept separate from `forward` so it can be traced on its own, which
        keeps compilation lighter. All-greedy batches take an ``argmax``
        (shape ``[..., 1]``); otherwise the sampler is invoked.
        """
        if not sampling_metadata.all_greedy:
            return self.sampler(logits, sampling_metadata).sampled_token_ids
        return torch.argmax(logits, dim=-1, keepdim=True)

1930
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def gather_logprobs(
        self, logits: torch.Tensor, sampled_tokens: torch.Tensor
    ) -> LogprobsTensors:
        """Gather top logprobs plus the logprob of each sampled token.

        A fixed count (``model_config.max_logprobs``) is gathered so a single
        compiled graph serves every request. The number each request actually
        asked for is selected later on CPU.
        """
        sampler = self.sampler
        return sampler.gather_logprobs(
            sampler.compute_logprobs(logits),
            self.model_config.max_logprobs,
            token_ids=sampled_tokens.squeeze(-1),
        )
1945

1946
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
1947
1948
1949
1950
1951
1952
1953
    def structured_decode(
        self,
        require_struct_decoding: torch.Tensor,
        grammar_bitmask: torch.Tensor,
        logits: torch.Tensor,
        arange: torch.Tensor,
    ) -> torch.Tensor:
1954
1955
1956
        return torch.where(
            require_struct_decoding,
            self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
1957
1958
            logits,
        )
1959

1960
1961
1962
1963
    def apply_grammar_bitmask(
        self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor
    ):
        """Return a copy of ``logits`` with grammar-disallowed tokens set to -inf.

        Each row of ``grammar_bitmask`` is a packed bit vector. Each word is
        unpacked by right-shifting it by every value in ``arange`` and keeping
        the low bit. The flattened bits are truncated to ``self.vocab_size``;
        tokens whose bit is 0 are masked. ``logits`` itself is not modified.
        """
        num_rows = logits.shape[0]
        assert num_rows == grammar_bitmask.shape[0]
        masked = logits.clone()
        for row in range(num_rows):
            bits = (
                torch.bitwise_right_shift(
                    grammar_bitmask[row].unsqueeze(1), arange.unsqueeze(0)
                )
                & 1
            )
            disallowed = (bits == 0).reshape(-1)[: self.vocab_size]
            masked[row] = masked[row].masked_fill(disallowed, float("-inf"))
        return masked

1976
1977
    def get_multimodal_embeddings(self, *args, **kwargs):
        """Forward all arguments to ``self.model.get_multimodal_embeddings``."""
        delegate = self.model.get_multimodal_embeddings
        return delegate(*args, **kwargs)
1978

1979
1980
1981
    def get_input_embeddings(self, *args, **kwargs):
        """Forward all arguments to ``self.model.get_input_embeddings``."""
        delegate = self.model.get_input_embeddings
        return delegate(*args, **kwargs)

1982
    def prepare_structured_decoding_input(
        self, logits: torch.Tensor, grammar_output: "GrammarOutput"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter the compacted grammar bitmask into batch order.

        Row ``i`` of ``grammar_output.grammar_bitmask`` is taken to belong to
        ``grammar_output.structured_output_request_ids[i]``. Each row is copied
        into ``self.grammar_bitmask_cpu`` at the request's batch index, and the
        request is flagged in ``self.require_structured_out_cpu``. Requests that
        are not in the current input batch are skipped. Their row index is
        still consumed, so the rows of later requests stay aligned.

        Args:
            logits: Logits of the current batch; only its row count and device
                are used.
            grammar_output: Structured-output request ids and their bitmask.

        Returns:
            ``(require_structured_out, grammar_bitmask, arange)`` moved to
            ``logits.device``. The first two are sliced to the batch size.
        """
        grammar_bitmask = grammar_output.grammar_bitmask
        num_reqs, _ = logits.shape

        # Reset pre-allocated tensors
        self.grammar_bitmask_cpu.zero_()
        self.require_structured_out_cpu.zero_()

        req_id_to_index = self.input_batch.req_id_to_index
        for mask_idx, req_id in enumerate(grammar_output.structured_output_request_ids):
            if req_id not in req_id_to_index:
                # Not in this batch: skip it, but its row `mask_idx` is still
                # consumed so the following requests read their own rows.
                continue
            batch_index = req_id_to_index[req_id]
            self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
                grammar_bitmask[mask_idx]
            )
            # Not every request in the batch needs structured output, so a bool
            # tensor records which ones do.
            self.require_structured_out_cpu[batch_index] = True

        return (
            self.require_structured_out_cpu[:num_reqs].to(logits.device),
            self.grammar_bitmask_cpu[:num_reqs].to(logits.device),
            self.structured_decode_arange.to(logits.device),
        )
2011

2012
2013
2014
2015
2016
2017
    def _get_mm_dummy_batch(
        self,
        modality: str,
        max_items_per_batch: int,
    ) -> BatchedTensorInputs:
        """Build dummy multimodal inputs for profiling and precompilation.

        A single dummy item of ``modality`` is obtained from the registry and
        repeated ``max_items_per_batch`` times. The kwargs of the first group
        produced by ``group_mm_kwargs_by_modality`` are returned.
        """
        budget = self.mm_budget
        assert budget is not None

        decoder_dummy = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_model_len,
            mm_counts={modality: 1},
            cache=budget.cache,
        )

        # Repeating one item yields the maximum memory consumption of the model.
        item = decoder_dummy.multi_modal_data[modality][0]
        items = [item] * max_items_per_batch

        mm_model = cast(SupportsMultiModal, self.model)
        groups = group_mm_kwargs_by_modality(
            items,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=mm_model.merge_by_field_config,
        )
        _, _, first_group_kwargs = next(iter(groups))
        return first_group_kwargs
2042

2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054

def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    """Return the ascending list of request-count padding buckets.

    Buckets are produced by repeatedly applying
    ``_get_padded_num_reqs_with_upper_limit``. They start at
    ``max(MIN_NUM_SEQS, min_req_size)`` clamped to ``max_req_size``, grow by
    powers of two, and end with ``max_req_size`` itself.

    Args:
        min_req_size: Smallest bucket to consider; must be a positive power
            of 2.
        max_req_size: Largest request count; always the final bucket.
    """
    logger.info("Preparing request paddings:")
    # assert min_req_size is power of 2
    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
    paddings: list[int] = []
    # Start from the padded (and clamped) size of min_req_size. If
    # max_req_size < MIN_NUM_SEQS, runtime padding returns max_req_size, so
    # that bucket must still be emitted rather than returning an empty list.
    num = _get_padded_num_reqs_with_upper_limit(min_req_size, max_req_size)
    while num <= max_req_size and (not paddings or paddings[-1] != num):
        paddings.append(num)
        logger.info("    %d", num)
        num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
    return paddings
2055
2056


2057
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Round a request count up to its padding bucket, capped at upper_limit.

    Counts up to ``MIN_NUM_SEQS`` map to ``MIN_NUM_SEQS``. Larger counts map
    to the smallest power of two that is >= ``x``. The result never exceeds
    ``upper_limit``.
    """
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)
2060
2061


2062
2063
2064
2065
def _get_token_paddings(
    min_token_size: int, max_token_size: int, padding_gap: int
) -> list[int]:
    """Generate a list of padding sizes, starting from min_token_size and
    ending with a number that can cover max_token_size.

    If padding_gap == 0 then:
        increase 2X each time (exponential)
    else:
        first double the size while it stays <= padding_gap,
        then increase the padding size by padding_gap.

    The returned list is strictly increasing and never empty. Its first
    element is min_token_size and its last element is >= max_token_size.

    Args:
        min_token_size: Smallest bucket; must be a positive power of 2.
        max_token_size: Value the last bucket must reach or exceed.
        padding_gap: 0 for exponential buckets, otherwise the fixed step used
            after the doubling phase. Must be non-negative.
    """
    # assert min_token_size is power of 2
    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
    # A negative gap would make the incremental loop below never terminate.
    assert padding_gap >= 0, f"padding_gap must be non-negative, got {padding_gap}"

    paddings: list[int] = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential token paddings:")
        while True:
            logger.info("    %d", num)
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        logger.info("Using incremental token paddings:")
        while num <= padding_gap:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2
        if not paddings:
            # min_token_size already exceeds padding_gap. Seed the
            # incremental phase with min_token_size itself; restarting from
            # min_token_size // 2 could emit buckets below the minimum (or
            # none at all).
            logger.info("    %d", min_token_size)
            paddings.append(min_token_size)
        num = paddings[-1]
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
2103
    """Return the first element in paddings list greater or equal to x."""
2104
2105
2106
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]
2107
2108


2109
2110
2111
def _get_padded_num_kv_cache_update_slices(
    num_tokens: int, max_num_reqs: int, page_size: int
) -> int:
2112
2113
    """Calculates the padded number of KV cache update slices to avoid
    recompilation."""
2114
2115
2116
2117
2118
    # NOTE(chengjiyao): let's say R_i is the token num for i-th request,
    # so it occupies most 2 + R_i // page_size pages. The total maximum
    # possible number of pages needed is sum(2 + R_i // page_size), which
    # is <= 2 * max_num_reqs + sum(R_i) // page_size
    # = 2 * max_num_reqs + num_tokens // page_size
2119
2120
2121
2122
2123
    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
    padded_num_slices = min(padded_num_slices, num_tokens)
    return padded_num_slices


2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    """Find the optimum number of slices to copy per Pallas program instance.

    Increasing the number of slices copied in one instance of the kernel program
    will increase HBM bandwidth utilization via more in-flight DMAs.

    However, it will also use more VMEM, and experimentally, we observed
    performance regression at 128 slices on v6e, likely due to running
    out of scalar registers. Thus this function will limit the number of
    slices to 64.
    """
    # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
    # calculate num_slices_per_block based on 16MB in case any register spills.
    vmem_limit = 16 * 1024 * 1024
    # Upper bound on slices per block (see docstring for the 128 regression).
    max_slices_per_block = 64

    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    # Snap to a power of two via prev_power_of_2, then apply the cap.
    return min(prev_power_of_2(num_slices_per_block), max_slices_per_block)


2146
2147
2148
2149
2150
2151
def replace_set_lora(model):
    """Wrap ``set_lora``/``reset_lora`` of every LoRA layer in ``model``.

    For each ``BaseLayerWithLoRA`` submodule, the currently bound methods are
    saved as ``_original_set_lora`` / ``_original_reset_lora``. They are
    replaced by wrappers that call the saved method and then issue
    ``torch_xla.sync(wait=False)``.

    Layers that already carry ``_original_set_lora`` are skipped. Re-wrapping
    them would make the saved "original" point at the wrapper itself, and the
    next ``set_lora`` call would recurse forever.
    """

    def _tpu_set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
    ):
        # TODO: The integer index leads to a recompilation, but converting it
        # to a tensor doesn't seem to work anymore. This might be fixed with a
        # later release of torch_xla.
        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
        torch_xla.sync(wait=False)

    def _tpu_reset_lora(self, index: int):
        self._original_reset_lora(index)
        torch_xla.sync(wait=False)

    for module in model.modules():
        if not isinstance(module, BaseLayerWithLoRA):
            continue
        if hasattr(module, "_original_set_lora"):
            # Already wrapped by an earlier call; keep the true originals.
            continue
        module._original_set_lora = module.set_lora
        module._original_reset_lora = module.reset_lora
        module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
        module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__)