tpu_model_runner.py 90.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import bisect
4
import gc
5
import time
6
from typing import TYPE_CHECKING, Any, cast
7
8
9
10
11
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
12

13
# TPU XLA related
14
import torch_xla
15
import torch_xla.core.xla_model as xm
16
import torch_xla.distributed.spmd as xs
17
18
import torch_xla.runtime as xr

19
import vllm.envs as envs
20
from vllm.attention import Attention
21
from vllm.attention.backends.abstract import AttentionType
22
from vllm.attention.layer import MLAAttention
23
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
24
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
25
26
27
28
29
30
31
from vllm.config import (
    ParallelConfig,
    VllmConfig,
    get_layers_from_vllm_config,
    update_config,
)
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
32
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
33
from vllm.forward_context import set_forward_context
34
from vllm.logger import init_logger
35
from vllm.lora.layers import BaseLayerWithLoRA
36
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
37
from vllm.model_executor.model_loader import get_model_loader
38
from vllm.model_executor.model_loader.tpu import TPUModelLoader
39
40
41
42
from vllm.model_executor.models.interfaces import (
    SupportsMultiModal,
    supports_transcription,
)
43
from vllm.model_executor.models.interfaces_base import (
44
45
46
    is_pooling_model,
    is_text_generation_model,
)
47
from vllm.multimodal import MULTIMODAL_REGISTRY
48
49
50
51
52
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    PlaceholderRange,
)
53
from vllm.multimodal.utils import group_mm_kwargs_by_modality
54
from vllm.sequence import IntermediateTensors
55
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
56
from vllm.utils.math_utils import cdiv, prev_power_of_2
57
from vllm.utils.platform_utils import is_pin_memory_available
58
59
60
61
62
63
64
65
66
67
68
from vllm.v1.attention.backends.pallas import (
    TPU_STR_DTYPE_TO_TORCH_DTYPE,
    PallasAttentionBackend,
    PallasMetadata,
    get_page_size_bytes,
)
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheSpec,
69
    MLAAttentionSpec,
70
71
72
73
74
75
76
77
    SlidingWindowSpec,
)
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    LogprobsLists,
    LogprobsTensors,
    ModelRunnerOutput,
)
78
79
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
80
from vllm.v1.worker.kv_connector_model_runner_mixin import (
81
82
83
    KVConnectorModelRunnerMixin,
    KVConnectorOutput,
)
84
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
85
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
86

87
88
89
90
91
92
from .utils import (
    MultiModalBudget,
    add_kv_sharing_layers_to_kv_cache_groups,
    bind_kv_cache,
    sanity_check_mm_encoder_outputs,
)
93

94
if TYPE_CHECKING:
95
    from vllm.v1.core.sched.output import SchedulerOutput
96
97
98

logger = init_logger(__name__)

# Placeholder token id. NOTE(review): presumably marks an invalid/absent
# token slot; its uses are outside this view — confirm.
INVALID_TOKEN_ID = -1

# Smallest output size. Lower bound on the number of request slots: it is the
# floor for `TPUModelRunner.max_num_reqs` and the smallest request padding.
MIN_NUM_SEQS = 8
102
103


104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#########################################################
# Ways to avoid recompilation
#########################################################
#
# The model executor has two primary components:
# 1. preparing the model and sampler inputs
# 2. executing the model and sampler.
# The core idea is to avoid any TPU computation during input preparation. For
# better compilation tracking and increased flexibility, the model execution and
# sampler are divided into several distinct components.
#
# Below are the detailed steps:
#
# Step 1
# It is recommended to avoid TPU operations when preparing the model and sampler
# inputs. CPU tensors can be prepared and transferred to the XLA device using
# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
# compilation.
#
# Step 2
# The TPU execution should be decomposed into subgraphs (4 at the moment):
# 1. the main model
# 2. selecting hidden states for each request
# 3. sampler
# 4. encoder.
# Each subgraph should be wrapped in torch.compile. This ensures that the
# subgraph topology is the same in both dummy_run and execute_model. The
# results from these subgraphs should either be passed to other subgraphs, or
# transferred from TPU to CPU using xla_tensor.cpu() for subsequent
# processing on the CPU.
#
# Step 3
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch conditions are included as subgraph inputs to facilitate
# pre-compilation.
139
class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
140
141
142
143
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        original_parallel_config: ParallelConfig | None = None,
    ):
        """Set up configuration, padding buckets and persistent CPU buffers.

        Args:
            vllm_config: Engine configuration. Its model, cache, LoRA, load,
                parallel, scheduler, speculative, observability and device
                sub-configs are cached on the instance.
            device: Device handed to the persistent ``InputBatch``.
            original_parallel_config: Stored as-is on the instance.
                NOTE(review): its use is outside this view.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        # SPMD Related
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        if self.use_spmd:
            num_devices = xr.global_runtime_device_count()
            mesh_shape = (num_devices, 1)
            device_ids = np.arange(num_devices)
            self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))

        self.enforce_eager = model_config.enforce_eager

        # Must run after `check_recompilation` / `enforce_eager` are set:
        # `_update_num_xla_graphs` reads both.
        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            # Follow the model dtype, which may be given as a string name.
            model_dtype = self.dtype
            if isinstance(model_dtype, str):
                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            else:
                self.kv_cache_dtype = model_dtype
        else:
            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
        self._hidden_states_dtype = self.dtype

        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.num_blocks_per_most_len_req = (
            cdiv(self.most_model_len, self.block_size)
            if self.most_model_len is not None
            else None
        )
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
        self.num_tokens_paddings = _get_token_paddings(
            min_token_size=16,
            max_token_size=scheduler_config.max_num_batched_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP,
        )
        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
        # padded max value to pre-allocate data structures and pre-compile.
        self.max_num_tokens = self.num_tokens_paddings[-1]

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, "attention"
        )
        self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.vocab_size = model_config.get_vocab_size()

        if self.lora_config is not None:
            self.vocab_size += self.lora_config.lora_extra_vocab_size

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        # Depends on `kv_cache_dtype`, `num_kv_heads` and `head_size` above.
        self._num_slices_per_kv_cache_update_block = (
            _get_num_slices_per_kv_cache_update_block(
                get_page_size_bytes(
                    block_size=self.block_size,
                    num_kv_heads=self.num_kv_heads,
                    head_size=self.head_size,
                    kv_cache_dtype=self.kv_cache_dtype,
                )
            )
        )

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Initialize input batch early to avoid AttributeError in _update_states
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
            kernel_block_sizes=[self.cache_config.block_size],
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device="cpu"
        )

        self.positions_cpu = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device="cpu"
        )
        self.positions_np = self.positions_cpu.numpy()
        # One row of block ids per request slot.
        self.block_table_cpu = torch.zeros(
            (self.max_num_reqs, self.max_num_blocks_per_req),
            dtype=torch.int32,
            device="cpu",
        )
        # adjust num_reqs to avoid SMEM OOM.
        self.num_reqs_most_model_len = (
            min(
                PallasAttentionBackend.get_max_num_seqs(
                    self.most_model_len, self.block_size
                ),
                self.max_num_reqs,
            )
            if self.most_model_len is not None
            else None
        )
        self.num_reqs_max_model_len = min(
            PallasAttentionBackend.get_max_num_seqs(
                self.max_model_len, self.block_size
            ),
            self.max_num_reqs,
        )
        self.query_start_loc_cpu = torch.zeros(
            self.max_num_tokens + 1,
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Only relevant for multimodal models
        if self.supports_mm_inputs:
            self.is_mm_embed_cpu = torch.zeros(
                self.max_num_tokens,
                dtype=torch.bool,
                device="cpu",
                pin_memory=self.pin_memory,
            )

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64)
        self.num_reqs_paddings = _get_req_paddings(
            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs
        )

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}

        # tensors for structured decoding
        # One int32 word per 32 vocab entries, per request slot.
        self.grammar_bitmask_cpu = torch.zeros(
            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.require_structured_out_cpu = torch.zeros(
            (self.max_num_reqs, 1),
            dtype=torch.bool,
            device="cpu",
            pin_memory=self.pin_memory,
        )
        self.structured_decode_arange = torch.arange(
            0, 32, device="cpu", pin_memory=self.pin_memory
        )

        self.mm_budget = (
            MultiModalBudget(
                self.model_config,
                self.scheduler_config,
                self.mm_registry,
            )
            if self.supports_mm_inputs
            else None
        )

        # The sampler is compiled as its own XLA subgraph except under SPMD,
        # where it is called directly.
        if not self.use_spmd:
            self.sample_from_logits_func = torch.compile(
                self.sample_from_logits,
                backend="openxla",
                fullgraph=True,
                dynamic=False,
            )
        else:
            self.sample_from_logits_func = self.sample_from_logits

375
376
377
378
    def reset_mm_cache(self) -> None:
        if self.mm_budget:
            self.mm_budget.reset_cache()

379
380
381
382
383
384
385
386
387
388
    def _update_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        total_cached_graphs = xr.get_num_cached_compilation_graph()
        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
        if new_compiled_graphs == 0:
            return

389
390
391
        logger.info(
            "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str
        )
392
393
394
395
396
397
398
399
400
401
402
        self.num_xla_graphs += new_compiled_graphs

    def _verify_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        assert self.num_xla_graphs == curr_cached_graph, (
            "Recompilation after warm up is detected during {}."
            " num_xla_graphs = {} curr_cached_graph = {}".format(
403
404
405
                case_str, self.num_xla_graphs, curr_cached_graph
            )
        )
406

407
408
409
410
411
412
413
414
    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        Returns:
415
            True if there is a new/resumed/paused/finished request.
416
417
418
419
420
421
422
423
424
425
426
427
            If False, we can skip copying SamplingMetadata to the GPU.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
428
        removed_req_indices: list[int] = []
429
430
431
432
433
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

434
        # Free the cached encoder outputs.
435
436
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)
437

438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

455
        req_ids_to_add: list[str] = []
456
457
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
458
            assert new_req_data.sampling_params is not None, (
459
                "Pooling is not supported in TPU yet"
460
            )
461
462
463
464
465
466
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
467
                prompt_embeds=new_req_data.prompt_embeds,
468
                mm_features=new_req_data.mm_features,
469
                sampling_params=sampling_params,
470
                pooling_params=None,
471
                generator=None,
472
473
474
475
476
477
478
479
480
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
481
482
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
483
            req_state = self.requests[req_id]
484
485
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
486
            resumed_from_preemption = req_id in req_data.resumed_req_ids
487
488

            # Update the cached states.
489
490
            req_state.num_computed_tokens = num_computed_tokens
            if not resumed_from_preemption:
491
492
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
493
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
494
                        block_ids.extend(new_ids)
495
            else:
496
                assert new_block_ids is not None
497
498
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
499
                req_state.block_ids = new_block_ids
500
501
502
503
504
505
506
507
508
509

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
510
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
511
            if new_block_ids is not None:
512
                self.input_batch.block_table.append_row(new_block_ids, req_index)
513
514
515
516
517
518

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
519
520
            # Fill the empty index or append to the end
            req_index = removed_req_indices.pop() if removed_req_indices else None
521
522
523
524
525
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)
526

527
528
529
530
531
        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

    def get_model(self) -> nn.Module:
        return self.model

532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """List the generation tasks the loaded model can serve.

        "generate" is listed for text-generation models and "transcription"
        for models that support transcription. A transcription-only model
        reports exactly ``["transcription"]``.
        """
        model = self.get_model()
        tasks: list[GenerationTask] = []

        if is_text_generation_model(model):
            tasks.append("generate")

        if not supports_transcription(model):
            return tasks

        if model.supports_transcription_only:
            return ["transcription"]

        tasks.append("transcription")
        return tasks

547
548
549
550
551
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """List the pooling tasks supported by the loaded model's pooler.

        Models that are not pooling models support no pooling tasks.
        """
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        return list(model.pooler.get_supported_tasks())
553

554
555
556
557
558
559
560
561
562
563
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        tasks = list[SupportedTask]()

        if self.model_config.runner_type == "generate":
            tasks.extend(self.get_supported_generation_tasks())
        if self.model_config.runner_type == "pooling":
            tasks.extend(self.get_supported_pooling_tasks())

        return tuple(tasks)

564
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Side effect: layers that reuse another layer's KV cache (cross-layer
        KV sharing) are recorded in ``self.shared_kv_cache_layers`` and get no
        spec of their own.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: If a layer uses encoder-decoder attention.
            ValueError: If an Attention layer has an unknown attention type.
        """
        layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
        block_size = self.vllm_config.cache_config.block_size
        cache_dtype_str = self.vllm_config.cache_config.cache_dtype

        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in layers.items():
            # Classic Attention path
            if isinstance(attn_module, Attention):
                if (
                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
                ) is not None:
                    # The layer doesn't need its own KV cache and will use that of
                    # the target layer. We skip creating a KVCacheSpec for it, so
                    # that KV cache management logic will act as this layer does
                    # not exist, and doesn't allocate KV cache for the layer. This
                    # enables the memory saving of cross-layer kv sharing, allowing
                    # a given amount of memory to accommodate longer context lengths
                    # or enable more requests to be processed simultaneously.
                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                    continue

                if attn_module.attn_type == AttentionType.DECODER:
                    if isinstance(attn_module, ChunkedLocalAttention):
                        logger.warning_once(
                            "Using irope in Pallas is not supported yet, it "
                            "will fall back to global attention for long context."
                        )
                    if attn_module.sliding_window is not None:
                        kv_cache_spec[layer_name] = SlidingWindowSpec(
                            block_size=block_size,
                            num_kv_heads=attn_module.num_kv_heads,
                            head_size=attn_module.head_size,
                            dtype=self.kv_cache_dtype,
                            sliding_window=attn_module.sliding_window,
                        )
                    else:
                        kv_cache_spec[layer_name] = FullAttentionSpec(
                            block_size=block_size,
                            num_kv_heads=attn_module.num_kv_heads,
                            head_size=attn_module.head_size,
                            dtype=self.kv_cache_dtype,
                        )
                elif attn_module.attn_type in (
                    AttentionType.ENCODER,
                    AttentionType.ENCODER_ONLY,
                ):
                    # encoder-only attention does not need KV cache.
                    continue
                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                    raise NotImplementedError(
                        "Encoder-decoder attention is not supported on TPU yet."
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
            # MLAAttention path
            elif isinstance(attn_module, MLAAttention):
                # Layer names are unique mapping keys, so no duplicate check is
                # needed. MLA layers are given a single KV head.
                kv_cache_spec[layer_name] = MLAAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
                    cache_dtype_str=cache_dtype_str,
                )
            # Any other AttentionLayerBase type gets no KV cache spec here.

        return kv_cache_spec

641
642
643
    def _get_slot_mapping_metadata(
        self, num_reqs, num_scheduled_tokens_per_req
    ) -> np.ndarray:
644
645
646
647
648
649
650
651
652
653
654
655
        """
        Computes metadata for mapping slots to blocks in the key-value (KV)
        cache for a batch of requests.

        This function determines, for each request in the batch, how the
        scheduled tokens are distributed across memory blocks, and generates
        metadata needed to map slices of tokens to their corresponding positions
        in the KV cache.

        Args:
            num_reqs (int): Number of requests in the current batch.
            num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
656
                to be scheduled for each request.
657
658
659

        Returns:
            np.ndarray: A 2D array of shape (total_block_len, 3), where each row
660
                contains:
661
                - kv_cache_start_index (int): The starting index in the KV cache
662
                  for the corresponding slice.
663
                - new_kv_start_index (int): The starting index in the new KV
664
                  cache for the corresponding slice.
665
666
667
                - slice_len (int): The length of the slice.
        """
        slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
668
669
670
671
        slices_end = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs]
            + num_scheduled_tokens_per_req
        )
672
673
674
675
        local_block_start_idx = slices_start // self.block_size
        local_block_end_idx = (slices_end - 1) // self.block_size
        no_repeat_req_indices = self.arange_np[:num_reqs]
        global_block_start_idx = (
676
677
            no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx
        )
678
679
680
681
682
683
684
        block_lens = local_block_end_idx - local_block_start_idx + 1
        global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
        slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
        global_block_indices = global_block_start_idx + slice_arange
        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
        total_block_len = np.sum(block_lens)
685
686
687
        slot_mapping_slices = np.repeat(
            np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0
        )
688
689
690
        cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
        np.cumsum(block_lens, out=cu_block_lens[1:])
        for req_idx in range(num_reqs):
691
692
693
694
695
696
            slot_mapping_slices[cu_block_lens[req_idx]][0] = (
                slices_start[req_idx] % self.block_size
            )
            slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = (
                slices_end[req_idx] - 1
            ) % self.block_size + 1
697
698
699
        slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
        cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
        np.cumsum(slice_lens, out=cu_slices_lens[1:])
700
701
702
        kv_cache_start_indices = slot_mapping_slices[:, 0] + (
            block_numbers * self.block_size
        )
703
704
        new_kv_start_indices = cu_slices_lens[:-1]
        slot_mapping_metadata = np.stack(
705
706
            [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1
        )
707
708
        return slot_mapping_metadata

709
    def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int):
        """Build the padded device inputs for one forward pass.

        Starting at ``start_index`` in ``self.input_batch``, takes as many
        requests as one execution allows (``self.num_reqs_max_model_len`` or
        ``self.num_reqs_most_model_len``, depending on whether any remaining
        request schedules more than ``self.most_model_len`` tokens). It fills
        the CPU staging buffers (input ids, positions, query start locations,
        sequence lengths, block tables, slot mapping), pads them to the
        precompiled sizes and copies them to ``self.device``.

        Side effects: assigns ``self.input_ids`` and ``self.position_ids``;
        when ``self.lora_config`` is set, calls ``self.set_active_loras``.

        Args:
            scheduler_output: Scheduler decisions for the current step.
            start_index: Index of the first request in ``self.input_batch``
                to include in this execution.

        Returns:
            ``(per_layer_attn_metadata, logits_indices, padded_num_reqs,
            num_reqs, end_index)``. ``per_layer_attn_metadata`` maps every
            ``Attention`` layer name to the same ``PallasMetadata``.
            ``end_index`` is the first request NOT included, i.e. the
            ``start_index`` for the next call.
        """
        assert scheduler_output.total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0
        assert start_index < num_reqs

        # Get the number of scheduled tokens for each remaining request. The
        # max_model_len shapes are used if most_model_len is not configured or
        # if any remaining request schedules more tokens than most_model_len.
        use_max_model_len = self.most_model_len is None
        num_scheduled_tokens_per_req = []
        for i in range(start_index, num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if not use_max_model_len and num_tokens > self.most_model_len:
                use_max_model_len = True
            num_scheduled_tokens_per_req.append(num_tokens)

        # Cap the number of requests at what the chosen shape allows; the
        # remaining requests are handled by a later call starting at end_index.
        max_reqs_this_step = (
            self.num_reqs_max_model_len
            if use_max_model_len
            else self.num_reqs_most_model_len
        )
        if len(num_scheduled_tokens_per_req) > max_reqs_this_step:
            num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
                :max_reqs_this_step
            ]
            end_index = start_index + max_reqs_this_step
        else:
            end_index = num_reqs

        max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
        num_scheduled_tokens_per_req = np.array(
            num_scheduled_tokens_per_req, dtype=np.int32
        )
        total_num_scheduled_tokens = int(num_scheduled_tokens_per_req.sum())
        assert max_num_scheduled_tokens_all_reqs > 0

        num_reqs = len(num_scheduled_tokens_per_req)

        # NOTE(review): from here on, request rows of self.input_batch are
        # indexed from 0 (self.arange_np[:num_reqs], the [:num_reqs] slices and
        # _get_slot_mapping_metadata(num_reqs, ...)), whereas
        # num_scheduled_tokens_per_req starts at start_index. These only line
        # up when start_index == 0 — TODO confirm the multi-execution path.

        # Request index of every scheduled token.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req)

        # Offset of every scheduled token within its request's scheduled chunk.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req]
        )

        # Get positions: tokens already computed plus the offset in this step.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(
            self.input_batch.num_computed_tokens_cpu[req_indices],
            arange,
            out=positions_np,
        )

        # Get token indices into the flattened token id table.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the row width of token_ids_cpu (the max_model_len).
        token_indices = (
            positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
        )

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(
            self.input_batch.token_ids_cpu_tensor.flatten(),
            0,
            torch.from_numpy(token_indices),
            out=self.input_ids_cpu[:total_num_scheduled_tokens],
        )

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(
            num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1]
        )
        self.query_start_loc_np[num_reqs + 1 :] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs]
            + num_scheduled_tokens_per_req
        )

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = _get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens
        )
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[
            total_num_scheduled_tokens:padded_total_num_scheduled_tokens
        ] = 0
        self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(
            self.device
        )
        self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(
            self.device
        )
        if use_max_model_len:
            block_tables = self.block_table_cpu[
                : self.num_reqs_max_model_len, : self.max_num_blocks_per_req
            ]
            block_tables[:num_reqs, : self.max_num_blocks_per_req] = (
                self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]
            )
            query_start_loc = self.query_start_loc_cpu[
                : self.num_reqs_max_model_len + 1
            ].to(self.device)
            seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
        else:
            block_tables = self.block_table_cpu[
                : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
            ]
            block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = (
                self.input_batch.block_table[0].get_cpu_tensor()[
                    :num_reqs, : self.num_blocks_per_most_len_req
                ]
            )
            query_start_loc = self.query_start_loc_cpu[
                : self.num_reqs_most_model_len + 1
            ].to(self.device)
            seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device)
        block_tables = block_tables.to(self.device)

        # Calculate the slot mapping, padded to a fixed number of slices and
        # transposed before the copy to the device.
        slot_mapping_metadata = self._get_slot_mapping_metadata(
            num_reqs, num_scheduled_tokens_per_req
        )
        num_kv_update_slices = slot_mapping_metadata.shape[0]
        padded_num_slices = _get_padded_num_kv_cache_update_slices(
            padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size
        )
        slot_mapping_metadata = np.pad(
            slot_mapping_metadata,
            [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
            constant_values=0,
        )
        slot_mapping_metadata = np.transpose(slot_mapping_metadata)
        slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device)

        if self.lora_config is not None:
            # We need to respect padding when activating LoRA adapters: the
            # padding tokens are attributed to the last request.
            padded_num_scheduled_tokens_per_req = np.copy(
                num_scheduled_tokens_per_req
            )  # Copying to avoid accidental state corruption bugs
            padded_num_scheduled_tokens_per_req[-1] += (
                padded_total_num_scheduled_tokens - total_num_scheduled_tokens
            )

            self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping_metadata,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device),
            num_kv_update_slices=torch.tensor(
                [num_kv_update_slices], dtype=torch.int32, device=self.device
            ),
            num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs
        )
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)

        # Every attention layer shares the same metadata object.
        layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
        per_layer_attn_metadata = {
            layer_name: attn_metadata for layer_name in layer_names
        }
        return (
            per_layer_attn_metadata,
            logits_indices,
            padded_num_reqs,
            num_reqs,
            end_index,
        )
917

918
    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
919
920
921
922
923
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
924
        mm_kwargs = list[MultiModalKwargsItem]()
925
926
        # List of tuple (mm_hash, pos_info)
        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
927
928
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
929
930

            for mm_input_id in encoder_input_ids:
931
932
933
934
                mm_feature = req_state.mm_features[mm_input_id]
                mm_hash = mm_feature.identifier
                mm_kwargs.append(mm_feature.data)
                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
935
936
937
938
939
940
941
942

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
943
        model = cast(SupportsMultiModal, self.model)
944
        encoder_outputs = []
945
        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
946
947
948
949
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
950
        ):
951
952
953
954
955
956
957
            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
958
            torch_xla.sync(wait=False)
959
            curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group)
960
            torch_xla.sync(wait=False)
961

962
963
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
964
                expected_num_items=num_items,
965
966
            )

967
968
969
970
971
972
            if isinstance(curr_group_outputs, torch.Tensor):
                encoder_outputs.append(curr_group_outputs)
            else:
                assert isinstance(curr_group_outputs, (list, tuple))
                for output in curr_group_outputs:
                    encoder_outputs.append(output)
973
974

        # Cache the encoder outputs.
975
976
977
        # NOTE (NickLucche) here we diverge from logic in other runners, as we
        # assume to only have whole mm items to process. Hence we avoid the
        # intrinsic dynamism that `scatter_mm_placeholders` introduces.
978
        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
979
980
981
            assert pos_info.is_embed is None, (
                "Expected all positions to be contiguous and embeddings."
            )
982
            self.encoder_cache[mm_hash] = output
983
984

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Collect cached encoder outputs that overlap this step's tokens.

        For every request, each multimodal item whose placeholder range
        intersects the window of tokens scheduled in this step contributes its
        whole cached encoder output, and the overlapping token positions are
        flagged in a boolean mask laid out over the flattened batch.

        Returns:
            ``(mm_embeds, is_mm_embed)``: the encoder outputs in request/item
            order, and the mask sliced to the padded token count and moved to
            ``self.device``.
        """
        padded_num_tokens = _get_padded_token_len(
            self.num_tokens_paddings, scheduler_output.total_num_scheduled_tokens
        )

        # Reuse the persistent CPU mask buffer; clear the part used this step.
        mm_mask = self.is_mm_embed_cpu
        mm_mask[:padded_num_tokens] = False
        mm_embeds = list[torch.Tensor]()

        # Index of the current request's first token in the flattened batch.
        token_offset = 0
        for req_id in self.input_batch.req_ids:
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            req_state = self.requests[req_id]
            computed = req_state.num_computed_tokens
            window_end = computed + scheduled

            # TODO unroll loop and assume/enforce --disable_chunked_mm_input
            # NOTE (NickLucche) here we diverge from logic in other runners, as
            # we assume to only have whole mm items to process. Hence we avoid
            # the intrinsic dynamism that `gather_mm_placeholders` introduces.
            for mm_feature in req_state.mm_features:
                placeholder = mm_feature.mm_position
                item_start = placeholder.offset
                item_len = placeholder.length

                # The item matters only if [item_start, item_start + item_len)
                # overlaps [computed, window_end).
                if item_start >= window_end:
                    # This and later items are not needed in this step.
                    break
                if item_start + item_len <= computed:
                    # Already processed and stored in the decoder's KV cache.
                    continue

                first = max(computed - item_start, 0)
                last = min(computed - item_start + scheduled, item_len)
                assert first < last

                mm_hash = mm_feature.identifier
                encoder_output = self.encoder_cache.get(mm_hash, None)
                assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."

                assert placeholder.is_embed is None, (
                    "Expected all positions to be contiguous and embeddings."
                )

                base = token_offset + item_start - computed
                mm_mask[base + first : base + last] = True

                # Only whole mm items are processed
                mm_embeds.append(encoder_output)

            token_offset += scheduled

        return mm_embeds, mm_mask[:padded_num_tokens].to(self.device)

    def _get_model_inputs(
        self,
        input_ids: torch.Tensor,
1054
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
1055
    ):
1056
        if self.supports_mm_inputs:
1057
1058
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

1059
1060
1061
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
1062
            inputs_embeds = self.model.get_input_embeddings(
1063
                input_ids,
1064
                multimodal_embeddings=mm_embeds,
1065
                is_multimodal=is_mm_embed,
1066
            )
1067

1068
1069
1070
1071
1072
1073
1074
1075
            return None, inputs_embeds
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            return input_ids, None

1076
1077
1078
1079
    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: IntermediateTensors | None = None,
    ) -> ModelRunnerOutput:
        """Execute one scheduler step and return the sampled tokens.

        Updates the cached request state and runs the multimodal encoder when
        the model supports it. It then runs the decoder and sampler one or more
        times: ``_prepare_inputs`` may split the batch into several executions,
        whose sampled tokens (and logprobs, if requested) are concatenated.
        Finally the sampled tokens are appended to the host-side request state.

        Args:
            scheduler_output: Scheduler decisions for the current step.
            intermediate_tensors: Accepted for interface compatibility; not
                used in this method.

        Returns:
            The ``ModelRunnerOutput`` for this step. When no tokens are
            scheduled, ``EMPTY_MODEL_RUNNER_OUTPUT`` (no KV transfer group) or
            the result of ``kv_connector_no_forward`` is returned instead.
        """
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            if not has_kv_transfer_group():
                # Return empty ModelRunnerOutput if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT

            return self.kv_connector_no_forward(scheduler_output, self.vllm_config)

        if self.supports_mm_inputs:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embed_inputs = None

        torch_xla.sync(wait=False)
        # Prepare inputs, the requests might be split into multiple
        # executions, combine the result of each execution.
        start_index = 0
        combined_selected_tokens: list[torch.Tensor] = []
        combined_logprobs: list[LogprobsLists] = []

        # NOTE: setup current batch's metadata for kv connector.
        # Currently, only verified with NixlConnector
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)

        while start_index < self.input_batch.num_reqs:
            attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = (
                self._prepare_inputs(scheduler_output, start_index)
            )
            input_ids, inputs_embeds = self._get_model_inputs(
                self.input_ids, mm_embed_inputs
            )
            torch_xla.sync(wait=False)
            # Run the decoder
            with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=scheduler_output.total_num_scheduled_tokens,
            ):
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self.position_ids,
                    inputs_embeds=inputs_embeds,
                )
            hidden_states = self.select_hidden_states(hidden_states, logits_indices)
            logits = self.compute_logits(hidden_states)
            tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
                self.input_batch, padded_num_reqs, self.device
            )
            if scheduler_output.grammar_bitmask is not None:
                require_struct_decoding, grammar_bitmask_padded, arange = (
                    self.prepare_structured_decoding_input(logits, scheduler_output)
                )
                logits = self.structured_decode(
                    require_struct_decoding, grammar_bitmask_padded, logits, arange
                )
            selected_token_ids = self.sample_from_logits_func(
                logits, tpu_sampling_metadata
            )
            # NOTE (NickLucche) Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs. We can't enforce it
            # due to recompilations outside torch.compiled code, so just make
            # sure `sample_from_logits` does not modify the logits in-place.
            logprobs = (
                self.gather_logprobs(logits, selected_token_ids)
                if tpu_sampling_metadata.logprobs
                else None
            )

            # Remove padding on cpu and keep dynamic op outside of xla graph.
            selected_token_ids = selected_token_ids.cpu()[:num_reqs]

            combined_selected_tokens.append(selected_token_ids)
            if tpu_sampling_metadata.logprobs:
                combined_logprobs.append(logprobs.tolists())

            start_index = end_index

        # NOTE: current kv load and save get h2d/d2h copies involved.
        # Those copies are blocking. Once they become async., kv_save
        # should be called right after each single forward pass,
        # instead of the forwards of the entire input batch.
        self.maybe_wait_for_kv_save()
        finished_sending, finished_recving = self.get_finished_kv_transfers(
            scheduler_output
        )

        selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
        if tpu_sampling_metadata.logprobs:
            # Flatten the per-execution logprob lists into batch-wide lists.
            logprobs_lists = LogprobsLists(
                logprob_token_ids=[
                    row for lp in combined_logprobs for row in lp.logprob_token_ids
                ],
                logprobs=[row for lp in combined_logprobs for row in lp.logprobs],
                sampled_token_ranks=[
                    rank for lp in combined_logprobs for rank in lp.sampled_token_ranks
                ],
            )
        else:
            logprobs_lists = None

        # Update the host-side request state with the sampled tokens.
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices = []
        num_reqs = self.input_batch.num_reqs
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (
                req_state.num_computed_tokens
                + scheduler_output.num_scheduled_tokens[req_id]
            )
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        assert all(
            req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]
        ), "req_ids contains None"
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        # Prompt logprobs are not supported yet; report None for every request.
        prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {
            req_id: None for req_id in req_ids
        }

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            #       the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1

        else:
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[i, target_slice] = (
                    valid_sampled_token_ids[i]
                )
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        kv_connector_output = (
            None
            if (finished_sending is None and finished_recving is None)
            else KVConnectorOutput(
                finished_sending=finished_sending,
                finished_recving=finished_recving,
            )
        )

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            kv_connector_output=kv_connector_output,
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

        return model_runner_output

1284
1285
1286
1287
1288
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Apply overrides to this runner's ``load_config`` / ``model_config``.

        Each key of ``overrides`` names a config attribute on ``self``; its
        value is passed to the module-level ``update_config`` helper and the
        result replaces the attribute. Only ``load_config`` and
        ``model_config`` are accepted (checked with ``assert``, entry by entry,
        so earlier valid entries are applied before an invalid one fails).
        """
        # TODO: TPU config may need extra validation
        # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
        allowed_config_names = {"load_config", "model_config"}
        for name, name_overrides in overrides.items():
            assert name in allowed_config_names, (
                f"Config `{name}` not supported. "
                f"Allowed configs: {allowed_config_names}"
            )
            current = getattr(self, name)
            setattr(self, name, update_config(current, name_overrides))

1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
    def load_model(self) -> None:
        """Load the model onto the TPU and create the sampler.

        Uses ``TPUModelLoader`` (with ``self.mesh``) when ``self.use_spmd`` is
        set, otherwise the generic loader from ``get_model_loader``. A
        ``RuntimeError`` during loading is re-raised with a hint about HBM
        capacity. When LoRA is configured the model is wrapped for LoRA. Pending
        XLA work is then synced, ``self.model`` is set only if it does not
        already exist, and ``self.sampler`` is (re)created.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        embedding_rank_patch = patch(
            "vllm.model_executor.layers.vocab_parallel_embedding."
            "get_tensor_model_parallel_rank",
            return_value=xm_tp_rank,
        )
        with embedding_rank_patch:
            try:
                if not self.use_spmd:
                    model_loader = get_model_loader(self.load_config)
                    logger.info("Loading model from scratch...")
                    model = model_loader.load_model(
                        vllm_config=self.vllm_config, model_config=self.model_config
                    )
                else:
                    tpu_loader = TPUModelLoader(
                        load_config=self.vllm_config.load_config
                    )
                    model = tpu_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.vllm_config.model_config,
                        mesh=self.mesh,
                    )
            except RuntimeError as e:
                raise RuntimeError(
                    "Unable to load model, a likely reason is the model is "
                    "too large for the current device's HBM memory. "
                    "Consider switching to a smaller model "
                    "or sharding the weights on more chips. "
                    f"See the detailed error: {e}"
                ) from e

        if self.lora_config is not None:
            model = self.load_lora_model(model, self.vllm_config, self.device)
            replace_set_lora(model)

        # Sync all pending XLA execution during model initialization and weight
        # loading.
        torch_xla.sync(wait=False)
        xm.wait_device_ops()

        # Keep an already-attached model; only assign on first load.
        if not hasattr(self, "model"):
            self.model = model
        self.sampler = TPUSampler()
1350

1351
    def reload_weights(self) -> None:
        """Reload weights in place into the model that is already loaded.

        Raises:
            AssertionError: if ``self.model`` is missing or ``None``.
        """
        model = getattr(self, "model", None)
        assert model is not None, "Cannot reload weights before model is loaded."

        loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        loader.load_weights(model, model_config=self.model_config)

1359
    @torch.no_grad()
    def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None:
        """Run one forward pass on zero-filled inputs of a given padded shape.

        Called by precompilation and profiling to trace/compile the model for a
        ``(num_tokens, num_reqs, num_blocks)`` combination. The dtype of the
        model output is recorded in ``self._hidden_states_dtype`` so later
        precompile steps can build dummy hidden states with it.

        Args:
            num_tokens: padded number of tokens in the batch.
            num_reqs: number of request rows in the dummy block table.
            num_blocks: number of block-table columns per request.
        """
        if self.supports_mm_inputs:
            # Multimodal models are driven through embeddings, not token ids.
            input_ids = None
            inputs_embeds = torch.zeros(
                (num_tokens, self.hidden_size), dtype=self.dtype, device=self.device
            )
        else:
            input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device)
            inputs_embeds = None
        actual_num_reqs = min(num_tokens, num_reqs)
        position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device)
        padded_num_slices = _get_padded_num_kv_cache_update_slices(
            num_tokens, self.max_num_reqs, self.block_size
        )
        num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to(
            self.device
        )
        slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to(
            self.device
        )
        block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to(
            self.device
        )
        # Every dummy request contributes a single query token.
        query_lens = [1] * num_reqs
        query_start_loc = torch.cumsum(
            torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
        ).to(self.device)
        context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device)
        num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
            num_kv_update_slices=num_kv_update_slices,
            num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
        )

        # Mark the size-varying dims as dynamic so dynamo does not specialize
        # the traced graph on these particular sizes.
        if self.supports_mm_inputs:
            torch._dynamo.mark_dynamic(inputs_embeds, 0)
        else:
            torch._dynamo.mark_dynamic(input_ids, 0)
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
        torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)

        # All attention layers share the same dummy metadata.
        layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
        per_layer_attn_metadata = {
            layer_name: attn_metadata for layer_name in layer_names
        }

        with (
            self.maybe_select_dummy_loras(
                self.lora_config, np.array([num_tokens], dtype=np.int32)
            ),
            set_forward_context(per_layer_attn_metadata, self.vllm_config, 0),
        ):
            out = self.model(
                input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds
            )
        self._hidden_states_dtype = out.dtype

    def _set_active_loras(
        self, prompt_lora_mapping, token_lora_mapping, lora_requests
    ) -> None:
        """Activate LoRA adapters via the parent implementation.

        The call is bracketed by ``torch_xla.sync`` so that pending input
        updates and the LoRA metadata updates are each captured by their own
        sync rather than being merged into the next traced graph.
        """
        torch_xla.sync(wait=False)  # Captures input updates
        super()._set_active_loras(
            prompt_lora_mapping, token_lora_mapping, lora_requests
        )
        torch_xla.sync(wait=False)  # Captures metadata updates

    def _precompile_mm_encoder(self) -> None:
1435
        if not self.supports_mm_inputs:
1436
1437
            return

1438
1439
        # Pre-compile MM encoder for all supported data modalities.
        hf_config = self.vllm_config.model_config.hf_config
1440
1441
1442
1443
1444
1445
1446

        mm_budget = self.mm_budget
        assert mm_budget is not None

        max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality  # noqa: E501

        for mode, max_items_per_seq in max_items_per_seq_by_modality.items():
1447
            logger.info(
1448
1449
                "Compiling Multimodal %s Encoder with different input shapes.", mode
            )
1450
1451
            start = time.perf_counter()
            # No padding for MM encoder just yet.
1452
            for num_items in range(1, max_items_per_seq + 1):
1453
1454
                logger.info("  -- mode: %s items: %d", mode, num_items)
                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
1455
1456
1457
                    mode,
                    num_items,
                )
1458
                # Run multimodal encoder.
1459
                torch_xla.sync(wait=False)
1460
                mm_embeds = self.model.get_multimodal_embeddings(
1461
1462
                    **batched_dummy_mm_inputs
                )
1463
                torch_xla.sync(wait=False)
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
                num_patches = mm_embeds[0].shape[0]
                items_size = num_patches * num_items

                # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm
                # embeddings are present. We assume `--disable-mm-chunked`,
                # hence only whole items can be scheduled. This implies we just
                # need to compile when `num_items` fit the (padded) `input_ids`
                for num_tokens in self.num_tokens_paddings:
                    if num_tokens >= items_size:
                        # XLA Workaround: if torch.zeros(..device) is used, XLA
                        # compiles a scalar+expansion op, which won't match
                        # the graph generated at runtime. CPU->TPU must be used
1476
1477
1478
                        placeholders_ids = torch.zeros(
                            num_tokens, dtype=torch.int32, device="cpu"
                        )
1479
                        # Align placeholders and actual num mm_embeddings.
1480
                        placeholders_ids[:items_size] = hf_config.image_token_index
1481
1482

                        placeholders_ids = placeholders_ids.to(self.device)
1483
1484
1485
1486

                        mm_mask = torch.tensor([False] * num_tokens)
                        mm_mask[:items_size] = True
                        mm_mask = mm_mask.to(self.device)
1487
                        # Assign outputs or the graph will be cut short.
1488
1489
1490
1491
                        a, b = self._get_model_inputs(
                            placeholders_ids,
                            mm_embed_inputs=([mm_embeds], mm_mask),
                        )
1492
                        assert a is None
1493
                        torch_xla.sync(wait=False)
1494
1495
1496
1497

            # Pre-compile `get_input_embeddings` when mm_embeddings are not
            # present. Chunk is only made of text, no mm_placeholders.
            for num_tokens in self.num_tokens_paddings:
1498
1499
1500
                placeholders_ids = torch.zeros(
                    num_tokens, dtype=torch.int32, device="cpu"
                )
1501
                placeholders_ids = placeholders_ids.to(self.device)
1502
1503
1504
1505
                a, b = self._get_model_inputs(
                    placeholders_ids,
                    mm_embed_inputs=None,
                )
1506
                assert a is None
1507
                torch_xla.sync(wait=False)
1508
1509
1510
1511

            xm.wait_device_ops()
            end = time.perf_counter()
            logger.info(
1512
1513
1514
1515
                "Multimodal %s Encoder compilation finished in in %.2f [secs].",
                mode,
                end - start,
            )
1516

1517
    def _precompile_backbone(self) -> None:
        """Compile the model forward pass for every padded token count.

        For each entry of ``self.num_tokens_paddings`` a dummy run is performed
        with the max-model-len request layout. When ``self.most_model_len`` is
        set, a second run is performed with the "most model len" layout.
        """
        logger.info("Compiling the model with different input shapes.")
        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(
                num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
            )
            if self.most_model_len is not None:
                self._dummy_run(
                    num_tokens,
                    self.num_reqs_most_model_len,
                    self.num_blocks_per_most_len_req,
                )
        # Block until queued device work finishes so the timing is meaningful.
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model backbone")

    def _precompile_select_hidden_states(self) -> None:
        """Compile ``select_hidden_states`` for bucketed token/request counts.

        For each token padding, request paddings are traced in increasing
        order. The loop stops at the first padding that reaches
        ``min(num_tokens, self.max_num_reqs)``.
        """
        # Compile hidden state selection function for bucketed
        # n_tokens x max_num_reqs. Graph is really small so this is fine.
        logger.info("Compiling select_hidden_states with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        for num_tokens in self.num_tokens_paddings:
            dummy_hidden = torch.zeros(
                (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype
            )
            torch._dynamo.mark_dynamic(dummy_hidden, 0)
            for num_reqs in self.num_reqs_paddings:
                indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
                torch._dynamo.mark_dynamic(indices, 0)
                self.select_hidden_states(dummy_hidden, indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs)
                # Requests can't be more than tokens. But do compile for the
                # next bigger value in case num_tokens uses bucketed padding.
                if num_reqs >= min(num_tokens, self.max_num_reqs):
                    break
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("select_hidden_states")

    def _precompile_compute_logits(self) -> None:
        """Compile ``compute_logits`` for every padded request count.

        Dummy hidden states of shape ``(num_reqs, hidden_size)`` use the dtype
        recorded by the backbone dummy run.
        """
        logger.info("Compiling compute_logits with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        for num_reqs in self.num_reqs_paddings:
            dummy_hidden = torch.zeros(
                (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype
            )
            torch._dynamo.mark_dynamic(dummy_hidden, 0)
            self.compute_logits(dummy_hidden)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("compute_logits")

    def _precompile_structured_decoding(self) -> None:
        """Compile ``structured_decode`` for every padded request count.

        The inputs are slices of the preallocated CPU buffers moved to the
        device, matching what ``prepare_structured_decoding_input`` produces.
        """
        logger.info("Compiling structured_decoding with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros(
                (num_reqs, self.vocab_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            dummy_require_struct_decoding = self.require_structured_out_cpu[
                :num_reqs
            ].to(self.device)
            dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device)
            # The first dimension of the above 3 dummy tensors cannot be
            # mark_dynamic because some operations in structured_decode require
            # them to be static.
            arange = self.structured_decode_arange.to(self.device)
            self.structured_decode(
                dummy_require_struct_decoding,
                dummy_grammar_bitmask,
                dummy_logits,
                arange,
            )
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("structured_decoding")

    def _precompile_sample_from_logits(self) -> None:
        """Compile the sampling function for every padded request count.

        Each request count is traced twice: once with ``all_greedy=False``
        (sampling params generated) and once with ``all_greedy=True``.
        """
        logger.info("Compiling sample_from_logits with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros(
                (num_reqs, self.vocab_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            # The first dimension of dummy_logits cannot be mark_dynamic
            # because some operations in the sampler require it to be static.
            for all_greedy in [False, True]:
                generate_params_if_all_greedy = not all_greedy
                sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
                    self.input_batch,
                    num_reqs,
                    self.device,
                    generate_params_if_all_greedy,
                )
                sampling_metadata.all_greedy = all_greedy
                with self.maybe_select_dummy_loras(
                    self.lora_config, np.array([num_reqs], dtype=np.int32)
                ):
                    self.sample_from_logits_func(dummy_logits, sampling_metadata)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sample_from_logits")

    def _precompile_gather_logprobs(self) -> None:
        """Compile ``gather_logprobs`` for every padded request count.

        Uses zero logits of shape ``(num_reqs, vocab_size)`` and zero sampled
        tokens of shape ``(num_reqs, 1)``.
        """
        logger.info("Compiling gather_logprobs with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros(
                (num_reqs, self.vocab_size),
                device=self.device,
                dtype=self._hidden_states_dtype,
            )
            dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device)
            with self.maybe_select_dummy_loras(
                self.lora_config, np.array([num_reqs], dtype=np.int32)
            ):
                self.gather_logprobs(dummy_logits, dummy_tokens)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("gather_logprobs")

    def capture_model(self) -> None:
        """
        Precompile all the subgraphs with possible input shapes.

        Runs every ``_precompile_*`` step in a fixed order inside the
        ``maybe_setup_dummy_loras`` context for ``self.lora_config``.
        """
        with self.maybe_setup_dummy_loras(self.lora_config):
            self._precompile_mm_encoder()
            self._precompile_backbone()
            self._precompile_select_hidden_states()
            self._precompile_compute_logits()
            self._precompile_structured_decoding()
            self._precompile_sample_from_logits()
            self._precompile_gather_logprobs()

    def profile_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run a profiling pass with ``num_tokens`` tokens.

        For multimodal models (unless ``skip_mm_profiling`` is set), the
        encoder is first run on a dummy batch of the modality with the most
        tokens, and its outputs are stored in ``self.encoder_cache["tmp"]``.
        Dummy forward runs for the general shape(s) follow. Finally the
        device is synced, the encoder cache is cleared and garbage is
        collected.
        """
        # Profile with multimodal encoder & encoder cache.
        if self.supports_mm_inputs:
            if self.model_config.multimodal_config.skip_mm_profiling:
                logger.info(
                    "Skipping memory profiling for multimodal encoder and "
                    "encoder cache."
                )
            else:
                mm_budget = self.mm_budget
                assert mm_budget is not None

                # TODO: handle encoder-decoder models once we support them.
                if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
                    # NOTE: Currently model is profiled with a single non-text
                    # modality with the max possible input tokens even when
                    # it supports multiple.
                    dummy_modality = mm_budget.get_modality_with_max_tokens()
                    max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
                        dummy_modality
                    ]

                    logger.info(
                        "Encoder cache will be initialized with a budget of "
                        "%s tokens, and profiled with %s %s items of the "
                        "maximum feature size.",
                        encoder_budget,
                        max_mm_items_per_batch,
                        dummy_modality,
                    )

                    # Create dummy batch of multimodal inputs.
                    batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                        dummy_modality,
                        max_mm_items_per_batch,
                    )

                    # Run multimodal encoder.
                    # Isolate encoder graph from post-processing to minimize
                    # impact of recompilation until it's fixed.
                    start = time.perf_counter()
                    torch_xla.sync(wait=False)
                    dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                        **batched_dummy_mm_inputs
                    )
                    torch_xla.sync(wait=False)
                    xm.wait_device_ops()
                    end = time.perf_counter()
                    logger.info(
                        "Multimodal Encoder profiling finished in %.2f [secs].",
                        end - start,
                    )

                    sanity_check_mm_encoder_outputs(
                        dummy_encoder_outputs,
                        expected_num_items=max_mm_items_per_batch,
                    )

                    # Cache the dummy encoder outputs.
                    # NOTE(review): presumably kept alive so they are resident
                    # during the dummy runs below; cleared before returning.
                    self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        # Trigger compilation for general shape.
        self._dummy_run(
            num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
        )
        if self.most_model_len is not None:
            self._dummy_run(
                num_tokens,
                self.num_reqs_most_model_len,
                self.num_blocks_per_most_len_req,
            )

        torch_xla.sync(wait=False)
        xm.wait_device_ops()
        self.encoder_cache.clear()
        gc.collect()

    def maybe_setup_cross_layer_kv_sharing(
        self,
        kv_caches: dict[str, torch.Tensor],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """
        Add layers that re-use KV cache to KV cache group of its target layer.
        Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`

        ``kv_caches`` is mutated in place: each sharing layer is mapped to the
        very same tensor object as its target layer. Does nothing when
        ``self.shared_kv_cache_layers`` is empty.
        """
        if not self.shared_kv_cache_layers:
            # No cross-layer KV sharing, return
            return

        add_kv_sharing_layers_to_kv_cache_groups(
            self.shared_kv_cache_layers,
            kv_cache_config.kv_cache_groups,
        )

        for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
            logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
            kv_caches[layer_name] = kv_caches[target_layer_name]

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Rebuilds ``self.input_batch`` if the configured block size differs from
        ``self.block_size``. Then one zero-filled device tensor is allocated
        per attention layer and the tensors are bound to the forward context.
        They are sharded over the mesh under SPMD and registered with the KV
        transfer group if one exists.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer

        Raises:
            NotImplementedError: if there is more than one KV cache group, or a
                layer's spec is not an ``AttentionSpec``.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not supported yet."
            )

        if (
            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
            != self.block_size
        ):
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=[
                    kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
                ],
                kernel_block_sizes=[
                    kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
                ],
            )
        # Verify dtype compatibility between block_table_cpu and input_batch
        assert (
            self.block_table_cpu.dtype
            == self.input_batch.block_table[0].get_cpu_tensor().dtype
        )

        # Byte size of the KV cache tensor owned by each layer.
        kv_cache_sizes = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            assert len(kv_cache_tensor.shared_by) == 1, (
                "KV cache tensor shared by multiple layers is not supported in TPU."
            )
            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

        kv_caches: dict[str, torch.Tensor] = {}
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_size = kv_cache_sizes[layer_name]
                assert tensor_size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_size // kv_cache_spec.page_size_bytes  # noqa
                if isinstance(kv_cache_spec, AttentionSpec):
                    if self.use_spmd:
                        num_kv_heads = kv_cache_spec.num_kv_heads
                        assert self.original_parallel_config is not None
                        tp_size = self.original_parallel_config.tensor_parallel_size
                        # TODO: Handle kv cache duplication under SPMD mode.
                        assert num_kv_heads % tp_size == 0, (
                            f"num_kv_heads {num_kv_heads} must be divisible by "
                            f"tp_size {tp_size} under SPMD mode"
                        )
                    kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                        num_blocks,
                        kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads,
                        kv_cache_spec.head_size,
                    )
                    dtype = kv_cache_spec.dtype

                    tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to(
                        self.device
                    )

                    kv_caches[layer_name] = tpu_kv_cache
                else:
                    raise NotImplementedError

        # Set up cross-layer KV cache sharing if needed
        self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches,
        )

        if self.use_spmd:
            # Shard KV Cache
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, "x", None, None))

        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)
            get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)

    def reset_dynamo_cache(self) -> None:
        """Drop dynamo's cached compiled code for the language backbone.

        Only acts when the backbone module is a
        ``TorchCompileWrapperWithCustomDispatcher``. In that case its original
        code object is removed from dynamo's cache and its
        ``compiled_codes`` list is cleared.
        """
        # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
        # since the compiled model object of the language backbone of a
        # multimodal model needs to be extracted via `get_language_model`.
        if self.model_config.is_multimodal_model:
            compiled_model = self.model.get_language_model().model
        else:
            compiled_model = self.model.model
        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
            logger.info("Clear dynamo cache and cached dynamo bytecode.")
            torch._dynamo.eval_frame.remove_from_cache(
                compiled_model.original_code_object
            )
            compiled_model.compiled_codes.clear()

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def select_hidden_states(self, hidden_states, indices_do_sample):
        """Return the rows of ``hidden_states`` indexed by ``indices_do_sample``.

        Compiled as its own small XLA graph (see
        ``_precompile_select_hidden_states`` for the shapes traced ahead of
        time).
        """
        return hidden_states[indices_do_sample]

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute logits by delegating to ``self.model.compute_logits``.

        Compiled as a separate XLA graph from the model forward pass.
        """
        return self.model.compute_logits(sample_hidden_states)

    # TODO: Under SPMD mode, sample_from_logits has correctness issue.
    #       Re-enable the torch.compile once the issue is fixed in torchxla.
    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_logits(
        self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata
    ) -> torch.Tensor:
        """
        Sample with xla-friendly function. This function is to be traced
        separately from `forward` for lighter compilation overhead.

        When ``sampling_metadata.all_greedy`` is set, the result is the argmax
        over the last dim with that dim kept (size 1). Otherwise the result is
        ``self.sampler(...).sampled_token_ids``.
        """
        if sampling_metadata.all_greedy:
            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids
        return out_tokens

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def gather_logprobs(
        self, logits: torch.Tensor, sampled_tokens: torch.Tensor
    ) -> LogprobsTensors:
        """
        Gather the top_logprobs with corresponding tokens. Use a fixed number
        of logprobs as an alternative to having multiple pre-compiled graphs.
        Select the number of logprobs actually demanded by each request on CPU.

        The fixed number is ``self.model_config.max_logprobs``.
        ``sampled_tokens`` has its trailing dim squeezed before use (callers
        pass shape ``(num_reqs, 1)``).
        """
        logprobs = self.sampler.compute_logprobs(logits)
        return self.sampler.gather_logprobs(
            logprobs,
            self.model_config.max_logprobs,
            token_ids=sampled_tokens.squeeze(-1),
        )

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def structured_decode(
        self,
        require_struct_decoding: torch.Tensor,
        grammar_bitmask: torch.Tensor,
        logits: torch.Tensor,
        arange: torch.Tensor,
    ) -> torch.Tensor:
        """Apply grammar masks only to the requests that need them.

        The bitmask is applied to the whole batch, and ``torch.where`` then
        keeps the masked logits where ``require_struct_decoding`` is true and
        the original logits elsewhere. ``require_struct_decoding`` is assumed
        to broadcast against ``logits`` (e.g. shape ``(num_reqs, 1)``) -
        confirm against the buffer allocation.
        """
        return torch.where(
            require_struct_decoding,
            self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
            logits,
        )

    def apply_grammar_bitmask(
        self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor
    ) -> torch.Tensor:
        """Return a copy of ``logits`` with grammar-disallowed tokens set to -inf.

        ``grammar_bitmask`` holds packed bits per request row. Every word is
        shifted right by each offset in ``arange`` (1-D) and its low bit is
        kept. After flattening, position ``w * len(arange) + j`` is bit
        ``arange[j]`` of word ``w``. The flattened bits are truncated to
        ``self.vocab_size``, and a cleared bit marks the token as disallowed.

        The unpack is done for the whole batch in one shot. A per-row Python
        loop would be unrolled ``num_reqs`` times when traced.
        """
        assert logits.shape[0] == grammar_bitmask.shape[0]
        # (num_reqs, num_words, 1) >> (len(arange),) -> (num_reqs, num_words, K)
        bits = torch.bitwise_right_shift(grammar_bitmask.unsqueeze(-1), arange) & 1
        # flatten(1) (rather than reshape(num_reqs, -1)) also handles 0 rows.
        disallowed = (bits == 0).flatten(1)[:, : self.vocab_size]
        return logits.masked_fill(disallowed, -float("inf"))

    def get_multimodal_embeddings(self, *args, **kwargs):
        """Forward all arguments to ``self.model.get_multimodal_embeddings``."""
        return self.model.get_multimodal_embeddings(*args, **kwargs)

    def get_input_embeddings(self, *args, **kwargs):
        """Forward all arguments to ``self.model.get_input_embeddings``."""
        return self.model.get_input_embeddings(*args, **kwargs)

    def prepare_structured_decoding_input(
        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the device inputs for ``structured_decode``.

        Rows of ``scheduler_output.grammar_bitmask`` are scattered into the
        preallocated CPU buffer at each request's batch index, and those batch
        rows are flagged in ``require_structured_out_cpu``. Requests not
        present in the input batch are skipped.

        Returns:
            ``(require_structured_out, grammar_bitmask, arange)``. The first two
            are the first ``num_reqs`` rows of the CPU buffers moved to
            ``logits.device``. The last is ``structured_decode_arange`` on that
            device.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        assert grammar_bitmask is not None
        num_reqs, _ = logits.shape

        # Reset pre-allocated tensors
        self.grammar_bitmask_cpu.zero_()
        self.require_structured_out_cpu.zero_()

        # Assumes the scheduler bitmask holds one row per entry of
        # ``structured_output_request_ids``, in order (confirm against the
        # scheduler). The row index must therefore advance even for skipped
        # requests, or every later request would read a neighbour's mask.
        request_ids = scheduler_output.structured_output_request_ids
        for mask_idx, req_id in enumerate(request_ids):
            if req_id not in self.input_batch.req_id_to_index:
                continue
            batch_index = self.input_batch.req_id_to_index[req_id]
            self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
                grammar_bitmask[mask_idx]
            )
            # It's not guaranteed that all requests in this batch require
            # structured output, so create a bool tensor to represent
            # the requests that need structured output.
            self.require_structured_out_cpu[batch_index] = True

        return (
            self.require_structured_out_cpu[:num_reqs].to(logits.device),
            self.grammar_bitmask_cpu[:num_reqs].to(logits.device),
            self.structured_decode_arange.to(logits.device),
        )

    def _get_mm_dummy_batch(
        self,
        modality: str,
        max_items_per_batch: int,
    ) -> BatchedTensorInputs:
        """Dummy data for profiling and precompiling multimodal models.

        One dummy item of ``modality`` is requested from the registry, then
        repeated ``max_items_per_batch`` times and grouped into batched kwargs.
        Only the first group yielded by ``group_mm_kwargs_by_modality`` is
        returned.
        """
        assert self.mm_budget is not None

        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_model_len,
            mm_counts={modality: 1},
            cache=self.mm_budget.cache,
        )
        dummy_mm_data = dummy_decoder_data.multi_modal_data

        # Result in the maximum device memory consumption of the model
        dummy_mm_item = dummy_mm_data[modality][0]
        dummy_mm_items = [dummy_mm_item] * max_items_per_batch

        model = cast(SupportsMultiModal, self.model)
        return next(
            grouped_mm_kwargs
            for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
                dummy_mm_items,
                device=self.device,
                pin_memory=self.pin_memory,
                merge_by_field_config=model.merge_by_field_config,
            )
        )


def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    """List the padded request counts to precompile for.

    Starts at ``max(MIN_NUM_SEQS, min_req_size)`` and advances with
    ``_get_padded_num_reqs_with_upper_limit(prev + 1, max_req_size)``.
    Stops once the value exceeds ``max_req_size`` or repeats the previous
    one, which happens after the helper clamps to ``max_req_size``. Each
    emitted size is logged.
    """
    logger.info("Preparing request paddings:")
    # min_req_size must be a positive power of two.
    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
    sizes: list[int] = []
    current = max(MIN_NUM_SEQS, min_req_size)
    while current <= max_req_size:
        if sizes and sizes[-1] == current:
            # Clamped to max_req_size again; nothing new to add.
            break
        sizes.append(current)
        logger.info("    %d", current)
        current = _get_padded_num_reqs_with_upper_limit(current + 1, max_req_size)
    return sizes
2030
2031


2032
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Pad the request count ``x`` for bucketing, capped at ``upper_limit``.

    Counts up to ``MIN_NUM_SEQS`` map to ``MIN_NUM_SEQS``. Larger counts round
    up to the next power of two. The result never exceeds ``upper_limit``.
    """
    if x <= MIN_NUM_SEQS:
        padded = MIN_NUM_SEQS
    else:
        # For positive x: the smallest power of two that is >= x.
        padded = 1 << (x - 1).bit_length()
    return min(padded, upper_limit)
2035
2036


2037
2038
2039
2040
def _get_token_paddings(
    min_token_size: int, max_token_size: int, padding_gap: int
) -> list[int]:
    """Generate ascending token-count padding buckets.

    The list starts at ``min_token_size`` and ends with the first value that
    is >= ``max_token_size``, so every token count up to ``max_token_size``
    has a bucket.

    If ``padding_gap == 0``:
        sizes double each step (exponential).
    Otherwise:
        sizes double while the doubled size stays <= ``padding_gap``, then
        grow linearly by ``padding_gap``.

    Args:
        min_token_size: First bucket; must be a positive power of two.
        max_token_size: Token count the last bucket must cover.
        padding_gap: Linear step size, or 0 for purely exponential buckets.
            Must be non-negative.

    Returns:
        Strictly increasing list of padding sizes.
    """
    # min_token_size must be a positive power of two.
    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
    # A negative gap would make the linear phase below never terminate.
    assert padding_gap >= 0, "padding_gap must be non-negative"

    paddings: list[int] = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential token paddings:")
        while True:
            logger.info("    %d", num)
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        logger.info("Using incremental token paddings:")
        # Doubling phase. Always emit min_token_size first, so buckets never
        # start below it even when min_token_size > padding_gap. Stop once
        # max_token_size is covered or the next doubling would exceed
        # padding_gap.
        while True:
            logger.info("    %d", num)
            paddings.append(num)
            if num >= max_token_size or num * 2 > padding_gap:
                break
            num *= 2
        # Linear phase: step by padding_gap from the last emitted bucket until
        # max_token_size is covered.
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
2078
    """Return the first element in paddings list greater or equal to x."""
2079
2080
2081
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]
2082
2083


2084
2085
2086
def _get_padded_num_kv_cache_update_slices(
    num_tokens: int, max_num_reqs: int, page_size: int
) -> int:
2087
2088
    """Calculates the padded number of KV cache update slices to avoid
    recompilation."""
2089
2090
2091
2092
2093
    # NOTE(chengjiyao): let's say R_i is the token num for i-th request,
    # so it occupies most 2 + R_i // page_size pages. The total maximum
    # possible number of pages needed is sum(2 + R_i // page_size), which
    # is <= 2 * max_num_reqs + sum(R_i) // page_size
    # = 2 * max_num_reqs + num_tokens // page_size
2094
2095
2096
2097
2098
    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
    padded_num_slices = min(padded_num_slices, num_tokens)
    return padded_num_slices


2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    """Pick how many slices one Pallas program instance should copy.

    Copying more slices per kernel instance keeps more DMAs in flight and
    raises HBM bandwidth utilization. It also consumes more VMEM. A
    performance regression was observed at 128 slices on v6e, likely from
    running out of scalar registers, so the result is capped at 64.

    The value is the largest power of two whose pages fit in a 16MB VMEM
    budget, capped at 64.
    """
    # The default vmem_limit_bytes of a pallas kernel is 32MB. Only 16MB is
    # budgeted here, in case of register spills.
    vmem_budget = 16 * 1024 * 1024
    slices_that_fit = vmem_budget // page_size_bytes
    assert slices_that_fit > 0, "Number of slices should be positive"
    return min(prev_power_of_2(slices_that_fit), 64)


2121
2122
2123
2124
2125
2126
def replace_set_lora(model):
    """Wrap LoRA set/reset on every ``BaseLayerWithLoRA`` in ``model``.

    For each matching submodule, the current ``set_lora`` / ``reset_lora``
    are saved as ``_original_set_lora`` / ``_original_reset_lora``. They are
    then replaced with bound wrappers that call the saved method and issue a
    non-blocking ``torch_xla.sync(wait=False)``.
    """

    def _tpu_set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
    ):
        # TODO: The integer index leads to a recompilation, but converting it
        # to a tensor doesn't seem to work anymore. This might be fixed with a
        # later release of torch_xla.
        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
        torch_xla.sync(wait=False)

    def _tpu_reset_lora(self, index: int):
        self._original_reset_lora(index)
        torch_xla.sync(wait=False)

    lora_layers = (
        mod for _, mod in model.named_modules() if isinstance(mod, BaseLayerWithLoRA)
    )
    for layer in lora_layers:
        layer_cls = layer.__class__
        layer._original_set_lora = layer.set_lora
        layer._original_reset_lora = layer.reset_lora
        # __get__ binds the plain functions as methods of this instance.
        layer.set_lora = _tpu_set_lora.__get__(layer, layer_cls)
        layer.reset_lora = _tpu_reset_lora.__get__(layer, layer_cls)