tpu_model_runner.py 76.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import bisect
4
import gc
5
import time
6
from typing import TYPE_CHECKING, Optional, cast
7
8
9
10
11
12
13
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
14
import torch_xla.distributed.spmd as xs
15
16
import torch_xla.runtime as xr

17
import vllm.envs as envs
18
19
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
20
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
21
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
22
from vllm.forward_context import set_forward_context
23
from vllm.logger import init_logger
24
from vllm.lora.layers import BaseLayerWithLoRA
25
from vllm.model_executor.model_loader import get_model_loader
26
from vllm.model_executor.model_loader.tpu import TPUModelLoader
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
29
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                    PlaceholderRange)
30
from vllm.multimodal.utils import group_mm_inputs_by_modality
31
from vllm.sequence import IntermediateTensors
32
33
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
                        is_pin_memory_available)
34
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
35
                                               PallasMetadata)
36
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
37
38
39
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec,
                                        SlidingWindowSpec)
40
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
41
                             ModelRunnerOutput)
42
43
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
44
45
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
46
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
47

48
49
from .utils import (initialize_kv_cache_for_kv_sharing,
                    sanity_check_mm_encoder_outputs)
50

51
if TYPE_CHECKING:
52
    from vllm.v1.core.sched.output import SchedulerOutput
53
54
55
56
57
58

# Module-level logger for this runner (used e.g. to report newly compiled
# XLA graphs).
logger = init_logger(__name__)

# Sentinel slot id written into the padded tail of the slot mapping. It is far
# out of range on purpose: we rely on out-of-bound indices being ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID: int = 10**9

# Placeholder id for an invalid token (its uses are outside this excerpt).
INVALID_TOKEN_ID: int = -1

# Smallest output size: lower bound for the number of request slots and the
# smallest request-count padding bucket.
MIN_NUM_SEQS: int = 8
62
63


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#########################################################
# Ways to avoid recompilation
#########################################################
#
# The model executor has two primary components:
# 1. preparing the model and sampler inputs
# 2. executing the model and sampler.
# The core idea is to avoid any TPU computation during input preparation. For
# better compilation tracking and increased flexibility, the model execution and
# sampler are divided into several distinct components.
#
# Below are the detailed steps:
#
# Step 1
# It is recommended to avoid TPU operations when preparing the model and sampler
# inputs. CPU tensors can be prepared and transferred to the XLA device using
# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
# compilation.
#
# Step 2
# The TPU execution should be decomposed into subgraphs (4 at the moment):
# 1. the main model
# 2. selecting hidden states for each request
# 3. sampler
# 4. encoder.
# Each subgraph should be wrapped in torch.compile. This is used to make
# sure that we have the same subgraph topology in both dummy_run and
# execute_model. The results from these subgraphs should either be passed to
# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
# subsequent processing on the CPU.
#
# Step 3
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch choices are exercised as subgraph inputs to facilitate
# pre-compilation.
99
class TPUModelRunner(LoRAModelRunnerMixin):
100
101
102
103
104

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        original_parallel_config: Optional[ParallelConfig] = None,
    ):
        """Initialize configuration, padding buckets and CPU staging buffers.

        Args:
            vllm_config: Engine configuration. Its sub-configs (model, cache,
                LoRA, load, parallel, scheduler, speculative, prompt adapter,
                observability, device) are cached as attributes.
            device: Target device; stored as ``self.device`` and handed to
                ``InputBatch``.
            original_parallel_config: Stored unchanged as
                ``self.original_parallel_config``; not otherwise read here.

        Note:
            ``self.model`` is only declared here (no value is assigned); it is
            expected to be set once the model is loaded.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.original_parallel_config = original_parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        # SPMD Related
        self.use_spmd = envs.VLLM_XLA_USE_SPMD
        if self.use_spmd:
            num_devices = xr.global_runtime_device_count()
            mesh_shape = (num_devices, 1)
            device_ids = np.array(range(num_devices))
            self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))

        self.enforce_eager = model_config.enforce_eager

        # `_update_num_xla_graphs` reads `check_recompilation` and
        # `enforce_eager`, so both must be set before this call.
        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]
        self._hidden_states_dtype = self.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
        self.num_tokens_paddings = _get_token_paddings(
            min_token_size=16,
            max_token_size=scheduler_config.max_num_batched_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
        # padded max value to pre-allocate data structures and pre-compile.
        self.max_num_tokens = self.num_tokens_paddings[-1]

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.vocab_size = model_config.get_vocab_size()
        if self.lora_config is not None:
            # LoRA may add extra vocabulary entries on top of the base model.
            self.vocab_size += self.lora_config.lora_extra_vocab_size

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Initialize input batch early to avoid AttributeError in _update_states
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            # NOTE: base-model vocab size, i.e. without LoRA extra vocab.
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")

        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.positions_np = self.positions_cpu.numpy()

        self.block_table_cpu = torch.zeros(
            (self.max_num_reqs, self.max_num_blocks_per_req),
            dtype=torch.int32,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64)
        self.num_reqs_paddings = _get_req_paddings(
            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}

        # tensors for structured decoding
        self.grammar_bitmask_cpu = torch.zeros(
            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
            dtype=torch.int32,
            device="cpu",
            pin_memory=self.pin_memory)
        self.require_structured_out_cpu = torch.zeros(
            (self.max_num_reqs, 1),
            dtype=torch.bool,
            device="cpu",
            pin_memory=self.pin_memory)
        self.structured_decode_arange = torch.arange(
            0, 32, device="cpu", pin_memory=self.pin_memory)

        # Get maximum number of mm items per modality (batch size).
        self.max_num_mm_items_by_modality = {}
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):
            max_tokens_by_modality_dict = (
                MULTIMODAL_REGISTRY.
                get_max_tokens_per_item_by_nonzero_modality(self.model_config))
            # Both values below are identical for every modality, so compute
            # them once instead of on each loop iteration.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)
            mm_limits_per_prompt = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)
            for modality, max_tokens in max_tokens_by_modality_dict.items():
                # Check how many items of this modality can be supported by
                # the encoder budget.
                max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                       max_tokens)

                # Check how many items of this modality can be supported by
                # the decoder budget.
                # NOTE: We do not consider max_num_batched_tokens on purpose
                # because the multimodal embeddings can be generated in advance
                # and chunked prefilled.
                max_num_mm_items_decoder_budget = (
                    self.max_num_reqs * mm_limits_per_prompt[modality])

                self.max_num_mm_items_by_modality[modality] = min(
                    max_num_mm_items_encoder_budget,
                    max_num_mm_items_decoder_budget)

        # Without SPMD the sampler is compiled once as a static-shape openxla
        # graph; with SPMD the uncompiled bound method is used directly.
        if not self.use_spmd:
            self.sample_from_logits_func = torch.compile(
                self.sample_from_logits,
                backend="openxla",
                fullgraph=True,
                dynamic=False)
        else:
            self.sample_from_logits_func = self.sample_from_logits

308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    def _update_num_xla_graphs(self, case_str):
        """Account for XLA graphs compiled since the last check.

        Does nothing unless recompilation checking is enabled and the runner
        is not in eager mode. Otherwise compares the runtime's cached-graph
        count with ``self.num_xla_graphs``. If new graphs appeared, it logs
        how many (attributing them to ``case_str``) and advances the counter.
        """
        if not self.check_recompilation or self.enforce_eager:
            return

        newly_compiled = (xr.get_num_cached_compilation_graph() -
                          self.num_xla_graphs)
        if newly_compiled != 0:
            logger.info("Add new %d compiled XLA graphs due to %s",
                        newly_compiled, case_str)
            self.num_xla_graphs += newly_compiled

    def _verify_num_xla_graphs(self, case_str):
        """Fail if any XLA graph was compiled since the last recorded count.

        Does nothing unless recompilation checking is enabled and the runner
        is not in eager mode.

        Args:
            case_str: Label for the code path being checked; included in the
                error message.

        Raises:
            AssertionError: if the runtime's cached-graph count differs from
                ``self.num_xla_graphs``. It is raised explicitly rather than
                via an ``assert`` statement so the check is not stripped
                when Python runs with ``-O``.
        """
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        if self.num_xla_graphs != curr_cached_graph:
            raise AssertionError(
                "Recompilation after warm up is detected during {}."
                " num_xla_graphs = {} curr_cached_graph = {}".format(
                    case_str, self.num_xla_graphs, curr_cached_graph))

333
334
335
336
337
338
339
340
    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Sync cached request states and the persistent batch with the
        scheduler output for this step.

        The updated states are consumed by ``_prepare_inputs`` to build the
        model input tensors for the device.

        Steps, in order:

        * Drop finished requests (cached state, encoder cache, batch slot).
        * Free the encoder outputs the scheduler released.
        * Evict requests not scheduled this step from the batch, keeping
          their cached state.
        * Register new requests and update running/resumed ones.
        * Place new or resumed requests into freed slots, lowest index first.
        * Condense any slots left empty.

        Returns:
            True if the persistent batch changed: a request was removed from
            it (finished or unscheduled) or a new/resumed request was added.
            If False, the batch holds the same requests in the same slots.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        # Any vacated slot (finished OR unscheduled request) changes the
        # batch. Record it now: the list is consumed by the refill below.
        batch_shrunk = bool(removed_req_indices)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=None,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first (the list is sorted
        # descending and popped from the end); `condense` below receives the
        # leftovers still in descending order.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        return batch_shrunk or len(req_ids_to_add) > 0

    def get_model(self) -> nn.Module:
        """Return the model module held by this runner.

        ``self.model`` is only declared (not assigned) in ``__init__``, so
        calling this before the model has been set raises ``AttributeError``.
        """
        model = self.model
        return model

458
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Describe the KV cache each attention layer needs.

        Walks every ``Attention`` layer registered in the vLLM config. Decoder
        layers get a ``FullAttentionSpec``, or a ``SlidingWindowSpec`` when
        the layer has a sliding window.

        Side effect: a layer with ``kv_sharing_target_layer_name`` set gets no
        spec. It is recorded in ``self.shared_kv_cache_layers`` as
        ``layer_name -> target_layer_name`` instead.

        Returns:
            Mapping from layer name to its ``KVCacheSpec``. Encoder,
            encoder-only and KV-sharing layers are omitted.

        Raises:
            NotImplementedError: for encoder-decoder attention layers.
            ValueError: for an unrecognized attention type.
        """
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        block_size = self.vllm_config.cache_config.block_size
        specs: dict[str, KVCacheSpec] = {}

        for name, module in attn_layers.items():
            shared_target = module.kv_sharing_target_layer_name
            if shared_target is not None:
                # Omitting the spec means no KV cache is allocated for this
                # layer; it uses the target layer's cache instead, which is
                # where cross-layer KV sharing saves memory.
                self.shared_kv_cache_layers[name] = shared_target
                continue

            attn_type = module.attn_type
            if attn_type == AttentionType.DECODER:
                spec_kwargs = {
                    "block_size": block_size,
                    "num_kv_heads": module.num_kv_heads,
                    "head_size": module.head_size,
                    "dtype": self.kv_cache_dtype,
                    "use_mla": False,
                }
                if module.sliding_window is None:
                    specs[name] = FullAttentionSpec(**spec_kwargs)
                else:
                    specs[name] = SlidingWindowSpec(
                        sliding_window=module.sliding_window, **spec_kwargs)
            elif attn_type in (AttentionType.ENCODER,
                               AttentionType.ENCODER_ONLY):
                # Encoder-only attention does not need a KV cache.
                continue
            elif attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown attention type: {attn_type}")

        return specs

513
    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        """Build the padded device inputs for one model step.

        Fills the pre-allocated CPU staging buffers for the ``num_reqs``
        requests currently in the persistent batch: token ids, positions,
        slot mapping, block tables, query start locations and sequence
        lengths. It pads them to the pre-compiled bucket sizes and copies
        them to ``self.device``.

        Side effects: ``self.input_ids`` and ``self.position_ids`` are set to
        the padded device tensors. When LoRA is configured, adapters are
        activated once for this step.

        Args:
            scheduler_output: Scheduler decisions for this step; at least one
                token must be scheduled.

        Returns:
            A tuple ``(per_layer_attn_metadata, logits_indices,
            padded_num_reqs)``:

            * ``per_layer_attn_metadata``: the same ``PallasMetadata`` object
              keyed by every attention layer name.
            * ``logits_indices``: device tensor holding, for each request,
              the flat index of its last scheduled token. It is padded to
              ``padded_num_reqs`` entries; padding entries point at index 0.
            * ``padded_num_reqs``: the padded request count.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Get the number of scheduled tokens for each request.
        num_scheduled_tokens_per_req = []
        max_num_scheduled_tokens_all_reqs = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens_per_req.append(num_tokens)
            max_num_scheduled_tokens_all_reqs = max(
                max_num_scheduled_tokens_all_reqs, num_tokens)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        assert max_num_scheduled_tokens_all_reqs > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, what are the corresponding req index.
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, what is its position in corresponding req.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.input_batch.block_table[0].
               slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_np[1:num_reqs + 1])
        # Padding request slots get a start offset of 1 so that their logits
        # index computed below is 1 - 1 = 0, which is always in range because
        # at least one token is scheduled.
        self.query_start_loc_np[num_reqs + 1:] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        padded_total_num_scheduled_tokens = _get_padded_token_len(
            self.num_tokens_paddings, total_num_scheduled_tokens)
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[
            total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
        self.input_ids = (
            self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(
                self.device))
        self.position_ids = (
            self.positions_cpu[:padded_total_num_scheduled_tokens].to(
                self.device))
        # Padding tokens map to the out-of-bound sentinel slot _PAD_SLOT_ID.
        self.input_batch.block_table[0].slot_mapping_cpu[
            total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = (
            self.input_batch.block_table[0].
            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
                self.device))
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

        if self.lora_config is not None:
            # Activate LoRA adapters exactly once per step. We need to respect
            # padding: the padding tokens are attributed to the last request
            # so the per-request counts add up to the padded token count.
            # Copy to avoid mutating num_scheduled_tokens_per_req.
            padded_num_scheduled_tokens_per_req = np.copy(
                num_scheduled_tokens_per_req)
            padded_num_scheduled_tokens_per_req[-1] += (
                padded_total_num_scheduled_tokens - total_num_scheduled_tokens)
            self.set_active_loras(self.input_batch,
                                  padded_num_scheduled_tokens_per_req)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs],
                                  dtype=torch.int32,
                                  device=self.device),
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)

        layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                  Attention).keys()
        per_layer_attn_metadata = {
            layer_name: attn_metadata
            for layer_name in layer_names
        }
        return per_layer_attn_metadata, logits_indices, padded_num_reqs
671

672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
    def _scatter_placeholders(
        self,
        embeds: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if is_embed is None:
            return embeds

        placeholders = embeds.new_full(
            (is_embed.shape[0], embeds.shape[-1]),
            fill_value=torch.nan,
        )
        placeholders[is_embed] = embeds
        return placeholders

    def _gather_placeholders(
        self,
        placeholders: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if is_embed is None:
            return placeholders

        return placeholders[is_embed]

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
698
699
700
701
702
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
703
704
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
705
706
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
707
708
709
710
711

            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
712
713
714
715
716
717
718
719
720
721
722
723
724

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
725
726
727
728
            batched_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_mm_inputs,
                device=self.device,
            )
729
730
731
732
733
734
735
736

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
737
            xm.mark_step()
738
739
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)
740
            xm.mark_step()
741

742
743
744
745
746
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

747
748
749
750
751
752
            if isinstance(curr_group_outputs, torch.Tensor):
                encoder_outputs.append(curr_group_outputs)
            else:
                assert isinstance(curr_group_outputs, (list, tuple))
                for output in curr_group_outputs:
                    encoder_outputs.append(output)
753
754

        # Cache the encoder outputs.
755
756
757
        # NOTE (NickLucche) here we diverge from logic in other runners, as we
        # assume to only have whole mm items to process. Hence we avoid the
        # intrinsic dynamism that `scatter_mm_placeholders` introduces.
758
759
760
761
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
762
763
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
764
765
766
            assert pos_info.is_embed is None, "Expected all positions to be"\
                " contiguous and embeddings."
            self.encoder_cache[req_id][input_id] = output
767
768

    def _gather_mm_embeddings(
769
770
771
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
772
        mm_embeds: list[torch.Tensor] = []
773
774
775
776
777
778
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
779
780
781
782
            # TODO unroll loop and assume/enforce --disable_chunked_mm_input
            # NOTE (NickLucche) here we diverge from logic in other runners, as
            # we assume to only have whole mm items to process. Hence we avoid
            # the intrinsic dynamism that `gather_mm_placeholders` introduces.
783
            for i, pos_info in enumerate(mm_positions):
784
785
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
801
802
                assert pos_info.is_embed is None, "Expected all positions to"\
                " be contiguous and embeddings."
803
                encoder_output = self.encoder_cache[req_id][i]
804
                mm_embeds.append(encoder_output)
805
        return mm_embeds
806

807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
    def _get_model_inputs(self, input_ids: torch.Tensor,
                          mm_embeds: list[torch.Tensor]):
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            return None, inputs_embeds
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            return input_ids, None

826
827
828
829
    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> ModelRunnerOutput:
        """Execute one scheduler step: forward pass, sampling, bookkeeping.

        Returns ``EMPTY_MODEL_RUNNER_OUTPUT`` when no tokens are scheduled.
        Otherwise the step does the following:
        - runs the multimodal encoder when the model is multimodal;
        - runs the decoder and the sampler, with optional structured
          decoding and logprobs;
        - appends the accepted tokens to ``self.input_batch`` and to each
          request's state;
        - returns a ``ModelRunnerOutput``.

        ``intermediate_tensors`` is accepted but not used by this method.
        """
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Nothing to run this step.
            return EMPTY_MODEL_RUNNER_OUTPUT

        mm_embeds = []
        if self.is_multimodal_model:
            # Run the multimodal encoder if any, then pick the embeddings
            # this step's tokens need.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        xm.mark_step()

        # Prepare inputs
        attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
            scheduler_output)
        input_ids, inputs_embeds = self._get_model_inputs(
            self.input_ids, mm_embeds)
        xm.mark_step()
        num_reqs = self.input_batch.num_reqs

        # Run the decoder
        with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=scheduler_output.total_num_scheduled_tokens):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=self.position_ids,
                inputs_embeds=inputs_embeds,
            )
        hidden_states = self.select_hidden_states(hidden_states,
                                                  logits_indices)
        logits = self.compute_logits(hidden_states)
        tpu_sampling_metadata = (
            TPUSupportedSamplingMetadata.from_input_batch(
                self.input_batch, padded_num_reqs, self.device))

        if scheduler_output.grammar_bitmask is not None:
            struct_inputs = self.prepare_structured_decoding_input(
                logits, scheduler_output)
            require_struct_decoding, grammar_bitmask_padded, arange = (
                struct_inputs)
            logits = self.structured_decode(require_struct_decoding,
                                            grammar_bitmask_padded, logits,
                                            arange)
        selected_token_ids = self.sample_from_logits_func(
            logits, tpu_sampling_metadata)

        # NOTE (NickLucche) Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs. We can't enforce it due
        # to recompilations outside torch.compiled code, so just make sure
        # `sample_from_logits` does not modify the logits in-place.
        logprobs = None
        if tpu_sampling_metadata.logprobs:
            logprobs = self.gather_logprobs(logits, selected_token_ids)

        # Remove padding on cpu and keep dynamic op outside of xla graph.
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
        logprobs_lists = None
        if tpu_sampling_metadata.logprobs:
            logprobs_lists = logprobs.tolists()

        # Update the cache state concurrently. Code above will not block until
        # we use `selected_token_ids`. Add mark_step if post-processing changes
        # request state.
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices = []
        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
                continue

            # Partial request: ignore its sampled token and rewind the
            # generator state as if the token was not sampled.
            generator = self.input_batch.generators.get(i)
            if generator is not None:
                # This relies on cuda-specific torch-internal impl details
                generator.set_offset(generator.get_offset() - 4)
            # Remember the index so its sampled tokens are cleared before
            # returning.
            discard_sampled_tokens_req_indices.append(i)

        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        # Prompt logprobs are not supported yet: report None for every request.
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {
            req_id: None
            for req_id in req_ids
        }

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            #       the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                new_token = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = new_token
                req_state.output_token_ids.append(new_token)
                self.input_batch.num_tokens[i] += 1
        else:
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")
        return model_runner_output

    def load_model(self) -> None:
        """Build the model on the TPU device, or reload its weights in place.

        With SPMD enabled, a fresh model is built via ``TPUModelLoader`` on
        ``self.mesh``. Otherwise, when ``self.model`` does not exist yet a
        model is built with the configured loader; when it does, the weights
        are loaded into the existing ``self.model`` in place.

        A freshly built model is wrapped for LoRA when ``self.lora_config``
        is set. Pending XLA work is then flushed. ``self.model`` is assigned
        only if it did not already exist. A new ``TPUSampler`` is created on
        every call.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        # Set when weights are loaded into the existing `self.model`; then
        # there is no freshly built model to wrap for LoRA.
        weights_reloaded_in_place = False
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            if self.use_spmd:
                tpu_loader = TPUModelLoader(
                    load_config=self.vllm_config.load_config)
                model = tpu_loader.load_model(
                    vllm_config=self.vllm_config,
                    model_config=self.vllm_config.model_config,
                    mesh=self.mesh)
            else:
                model_loader = get_model_loader(self.load_config)
                if not hasattr(self, "model"):
                    logger.info("Loading model from scratch...")
                    model = model_loader.load_model(
                        vllm_config=self.vllm_config,
                        model_config=self.model_config)
                else:
                    # Adjacent literals, not a backslash continuation: the
                    # latter embedded the next line's indentation in the log.
                    logger.info("Model was already initialized. "
                                "Loading weights inplace...")
                    model_loader.load_weights(self.model,
                                              model_config=self.model_config)
                    # Bind `model` so the code below never reads an unbound
                    # local on this path (it used to raise NameError when
                    # LoRA was enabled).
                    model = self.model
                    weights_reloaded_in_place = True

        if self.lora_config is not None and not weights_reloaded_in_place:
            # NOTE(review): skipping on in-place reload assumes the existing
            # self.model was LoRA-wrapped when this method first built it;
            # confirm nothing else assigns self.model.
            model = self.load_lora_model(model, self.model_config,
                                         self.scheduler_config,
                                         self.lora_config, self.device)
            replace_set_lora(model)

        # Sync all pending XLA execution during model initialization and weight
        # loading.
        xm.mark_step()
        xm.wait_device_ops()
        if not hasattr(self, "model"):
            self.model = model
        self.sampler = TPUSampler()
1015

1016
    @torch.no_grad()
    def _dummy_run(self, num_tokens: int) -> None:
        """Run the model once on zero-filled inputs of ``num_tokens`` tokens.

        Builds the following dummy inputs:
        - placeholder token ids, or placeholder input embeddings for
          multimodal models;
        - positions and a slot mapping;
        - ``PallasMetadata`` sized for ``self.max_num_reqs`` requests.

        It marks the token dimension of those inputs dynamic for
        ``torch._dynamo`` and calls the model once. The output dtype is
        recorded in ``self._hidden_states_dtype``.
        """
        device = self.device
        max_reqs = self.max_num_reqs

        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
                                        device=device)
        else:
            input_ids = torch.zeros(num_tokens, dtype=torch.int32).to(device)
            inputs_embeds = None

        num_seqs_value = min(num_tokens, max_reqs)
        position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(device)
        slot_mapping = torch.zeros(num_tokens, dtype=torch.int64).to(device)
        block_tables = torch.zeros((max_reqs, self.block_table_cpu.shape[1]),
                                   dtype=torch.int32).to(device)
        # Every dummy request has query length 1, so the cumulative start
        # offsets are simply 0, 1, ..., max_num_reqs.
        query_start_loc = torch.arange(max_reqs + 1,
                                       dtype=torch.int32).to(device)
        context_lens = torch.ones((max_reqs, ), dtype=torch.int32).to(device)
        num_seqs = torch.tensor([num_seqs_value],
                                dtype=torch.int32).to(device)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
        )

        # Let dynamo treat the token dimension of the inputs as dynamic.
        torch._dynamo.mark_dynamic(
            inputs_embeds if self.is_multimodal_model else input_ids, 0)
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

        # Every attention layer shares the same dummy metadata.
        per_layer_attn_metadata = dict.fromkeys(
            get_layers_from_vllm_config(self.vllm_config, Attention),
            attn_metadata)

        lora_token_counts = np.array([num_tokens], dtype=np.int32)
        with self.maybe_select_dummy_loras(self.lora_config,
                                           lora_token_counts):
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config, 0):
                out = self.model(input_ids=input_ids,
                                 positions=position_ids,
                                 inputs_embeds=inputs_embeds)
        self._hidden_states_dtype = out.dtype
1074

1075
1076
1077
1078
1079
1080
1081
    def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                          lora_requests) -> None:
        """Activate LoRA adapters, with XLA step boundaries around the update.

        Delegates to the parent implementation. An ``xm.mark_step()`` before
        the call captures the pending input updates, and one after it
        captures the metadata updates the parent makes.
        """
        parent_args = (prompt_lora_mapping, token_lora_mapping, lora_requests)
        xm.mark_step()  # Captures input updates
        super()._set_active_loras(*parent_args)
        xm.mark_step()  # Captures metadata updates

1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
    def _precompile_mm_encoder(self) -> None:
        """Pre-compile the multimodal encoder and input-embedding graphs.

        For each modality in ``self.max_num_mm_items_by_modality`` this
        compiles the following:
        - ``get_multimodal_embeddings`` for 1..max items;
        - ``get_input_embeddings`` with those embeddings, for every padded
          token count large enough to hold them;
        - ``get_input_embeddings`` without multimodal embeddings (text-only
          chunks), for every padded token count.
        """
        # Pre-compile MM encoder for all supported data modalities.
        hf_config = self.vllm_config.model_config.hf_config
        for mode, max_items_by_mode in \
                self.max_num_mm_items_by_modality.items():
            logger.info(
                "Compiling Multimodal %s Encoder with different input"
                " shapes.", mode)
            start = time.perf_counter()
            # No padding for MM encoder just yet.
            for num_items in range(1, max_items_by_mode + 1):
                logger.info("  -- mode: %s items: %d", mode, num_items)
                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                    mode, num_items)
                # Run multimodal encoder.
                xm.mark_step()
                mm_embeds = self.model.\
                    get_multimodal_embeddings(**batched_dummy_mm_inputs)
                xm.mark_step()
                num_patches = mm_embeds[0].shape[0]
                items_size = num_patches * num_items

                # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm
                # embeddings are present. We assume `--disable-mm-chunked`,
                # hence only whole items can be scheduled. This implies we just
                # need to compile when `num_items` fit the (padded) `input_ids`
                for num_tokens in self.num_tokens_paddings:
                    if num_tokens < items_size:
                        continue
                    # XLA Workaround: if torch.zeros(..device) is used, XLA
                    # compiles a scalar+expansion op, which won't match
                    # the graph generated at runtime. CPU->TPU must be used
                    placeholders_ids = torch.zeros(num_tokens,
                                                   dtype=torch.int32,
                                                   device="cpu")
                    # Align placeholders and actual num mm_embeddings.
                    placeholders_ids[:items_size] = \
                        hf_config.image_token_index

                    placeholders_ids = placeholders_ids.to(self.device)
                    # Keep both outputs bound until mark_step, otherwise the
                    # traced graph is cut short.
                    ids_out, embeds_out = self._get_model_inputs(
                        placeholders_ids, [mm_embeds])
                    assert ids_out is None
                    xm.mark_step()

            # Pre-compile `get_input_embeddings` when mm_embeddings are not
            # present. Chunk is only made of text, no mm_placeholders.
            for num_tokens in self.num_tokens_paddings:
                placeholders_ids = torch.zeros(num_tokens,
                                               dtype=torch.int32,
                                               device="cpu")
                placeholders_ids = placeholders_ids.to(self.device)
                ids_out, embeds_out = self._get_model_inputs(
                    placeholders_ids, [])
                assert ids_out is None
                xm.mark_step()

            xm.wait_device_ops()
            end = time.perf_counter()
            # Fixed duplicated word ("in in") in the log message.
            logger.info(
                "Multimodal %s Encoder compilation finished in %.2f "
                "[secs].", mode, end - start)

1144
    def _precompile_backbone(self) -> None:
        """Compile the model backbone for every padded token count.

        Calls ``_dummy_run`` once per entry of ``self.num_tokens_paddings``,
        waits for the device, logs the elapsed time and records the graph
        count under ``"model backbone"``.
        """
        logger.info("Compiling the model with different input shapes.")
        t0 = time.perf_counter()
        for padded_len in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", padded_len)
            self._dummy_run(padded_len)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - t0
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("model backbone")
1154

1155
1156
1157
1158
1159
    def _precompile_select_hidden_states(self) -> None:
        """Compile ``select_hidden_states`` over (num_tokens, num_reqs) buckets.

        For each padded token count, zero hidden states are paired with index
        tensors for each padded request count. This continues up to the first
        request bucket that reaches ``min(num_tokens, self.max_num_reqs)``.
        The graph is small, so compiling many shapes is cheap.
        """
        logger.info(
            "Compiling select_hidden_states with different input shapes.")
        t0 = time.perf_counter()
        hidden_size = self.model_config.get_hidden_size()
        for num_tokens in self.num_tokens_paddings:
            hidden = torch.zeros((num_tokens, hidden_size),
                                 device=self.device,
                                 dtype=self._hidden_states_dtype)
            torch._dynamo.mark_dynamic(hidden, 0)
            # Requests can't be more than tokens. But do compile for the
            # next bigger value in case num_tokens uses bucketed padding.
            req_cap = min(num_tokens, self.max_num_reqs)
            for padded_reqs in self.num_reqs_paddings:
                indices = torch.zeros(padded_reqs,
                                      dtype=torch.int32,
                                      device=self.device)
                torch._dynamo.mark_dynamic(indices, 0)
                self.select_hidden_states(hidden, indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                            padded_reqs)
                if padded_reqs >= req_cap:
                    break
        xm.wait_device_ops()
        elapsed = time.perf_counter() - t0
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("select_hidden_states")
1183

1184
1185
    def _precompile_compute_logits(self) -> None:
        """Compile ``compute_logits`` for every padded request count.

        Feeds zero hidden states of shape ``(num_reqs, hidden_size)`` for each
        entry of ``self.num_reqs_paddings``, with the first dimension marked
        dynamic.
        """
        logger.info("Compiling compute_logits with different input shapes.")
        t0 = time.perf_counter()
        hidden_size = self.model_config.get_hidden_size()
        for padded_reqs in self.num_reqs_paddings:
            hidden = torch.zeros((padded_reqs, hidden_size),
                                 device=self.device,
                                 dtype=self._hidden_states_dtype)
            torch._dynamo.mark_dynamic(hidden, 0)
            self.compute_logits(hidden)
            logger.info("  -- num_seqs: %d", padded_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - t0
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("compute_logits")

    def _precompile_structured_decoding(self) -> None:
        """Compile ``structured_decode`` for every padded request count.

        Uses the following inputs:
        - zero logits;
        - the leading ``num_reqs`` rows of ``self.require_structured_out_cpu``
          and ``self.grammar_bitmask_cpu``, moved to the device;
        - ``self.structured_decode_arange``.
        """
        logger.info(
            "Compiling structured_decoding with different input shapes.")
        t0 = time.perf_counter()
        for padded_reqs in self.num_reqs_paddings:
            logits = torch.zeros((padded_reqs, self.vocab_size),
                                 device=self.device,
                                 dtype=self._hidden_states_dtype)
            require_struct = \
                self.require_structured_out_cpu[:padded_reqs].to(self.device)
            bitmask = self.grammar_bitmask_cpu[:padded_reqs].to(self.device)
            # The first dimension of the above 3 dummy tensors cannot be
            # mark_dynamic because some operations in structured_decode require
            # them to be static.
            arange = self.structured_decode_arange.to(self.device)
            self.structured_decode(require_struct, bitmask, logits, arange)
            logger.info("  -- num_seqs: %d", padded_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - t0
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("structured_decoding")

    def _precompile_sample_from_logits(self) -> None:
        """Compile ``sample_from_logits_func`` for every padded request count.

        Each request bucket is compiled twice: once with random sampling and
        once with ``all_greedy``. The sampling metadata is built from
        ``self.input_batch`` under the dummy-LoRA context.
        """
        logger.info(
            "Compiling sample_from_logits with different input shapes.")
        t0 = time.perf_counter()
        for padded_reqs in self.num_reqs_paddings:
            logits = torch.zeros((padded_reqs, self.vocab_size),
                                 device=self.device,
                                 dtype=self._hidden_states_dtype)
            # The first dimension of the dummy logits cannot be mark_dynamic
            # because some operations in the sampler require it to be static.
            for all_greedy in (False, True):
                metadata = TPUSupportedSamplingMetadata.from_input_batch(
                    self.input_batch,
                    padded_reqs,
                    self.device,
                    not all_greedy,  # generate_params_if_all_greedy
                )
                metadata.all_greedy = all_greedy
                dummy_lora_counts = np.array([padded_reqs], dtype=np.int32)
                with self.maybe_select_dummy_loras(self.lora_config,
                                                   dummy_lora_counts):
                    self.sample_from_logits_func(logits, metadata)
            logger.info("  -- num_seqs: %d", padded_reqs)
        xm.wait_device_ops()
        elapsed = time.perf_counter() - t0
        logger.info("Compilation finished in %.2f [secs].", elapsed)
        self._update_num_xla_graphs("sample_from_logits")
1254

1255
1256
1257
1258
1259
1260
1261
1262
1263
    def _precompile_gather_logprobs(self) -> None:
        """Pre-compile `gather_logprobs` for every padded request count.

        For each entry of ``self.num_reqs_paddings``, zero logits of shape
        ``(num_reqs, self.vocab_size)`` and zero token ids of shape
        ``(num_reqs, 1)`` are fed through `gather_logprobs` inside
        ``self.maybe_select_dummy_loras``.
        """
        logger.info("Compiling gather_logprobs with different input shapes.")
        start = time.perf_counter()
        for num_reqs in self.num_reqs_paddings:
            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
                                       device=self.device,
                                       dtype=self._hidden_states_dtype)
            dummy_tokens = torch.zeros((num_reqs, 1),
                                       dtype=torch.int64).to(self.device)
            with self.maybe_select_dummy_loras(
                    self.lora_config, np.array([num_reqs], dtype=np.int32)):
                self.gather_logprobs(dummy_logits, dummy_tokens)
            logger.info("  -- num_seqs: %d", num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("gather_logprobs")

1273
1274
1275
1276
    def capture_model(self) -> None:
        """
        Precompile all the subgraphs with possible input shapes.

        The steps run in a fixed order: multimodal encoder, backbone,
        hidden-state selection, logits, structured decoding, sampling and
        logprob gathering. All of them run inside
        ``self.maybe_setup_dummy_loras(self.lora_config)``.
        """
        with self.maybe_setup_dummy_loras(self.lora_config):
            self._precompile_mm_encoder()
            self._precompile_backbone()
            self._precompile_select_hidden_states()
            self._precompile_compute_logits()
            self._precompile_structured_decoding()
            self._precompile_sample_from_logits()
            self._precompile_gather_logprobs()
1285

1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
    def profile_run(
        self,
        num_tokens: int,
    ) -> None:
        """Run dummy passes used for profiling.

        For multimodal models with a non-zero encoder budget, the multimodal
        encoder is first run on a dummy batch and its outputs are stored in
        ``self.encoder_cache["tmp"]``. Then `_dummy_run` is executed with
        ``num_tokens``. Finally the encoder cache is cleared and garbage is
        collected.

        Args:
            num_tokens: Number of tokens passed to `_dummy_run`.
        """
        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            dummy_data_modality, max_num_mm_items = max(
                self.max_num_mm_items_by_modality.items(), key=lambda t: t[1])

            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            logger.info(
                "Encoder cache will be initialized with a budget of %d tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            batched_dummy_mm_inputs = self._get_mm_dummy_batch(
                dummy_data_modality, max_num_mm_items)

            # Run multimodal encoder.
            # Isolate encoder graph from post-processing to minimize
            # impact of recompilation until it's fixed.
            start = time.perf_counter()
            xm.mark_step()
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            xm.mark_step()
            xm.wait_device_ops()
            end = time.perf_counter()
            logger.info(
                "Multimodal Encoder profiling finished in %.2f [secs].",
                end - start)

            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        # Trigger compilation for general shape.
        self._dummy_run(num_tokens)

        xm.mark_step()
        xm.wait_device_ops()
        # Drop the dummy encoder outputs stored above.
        self.encoder_cache.clear()
        gc.collect()

1345
1346
1347
1348
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        One zero-filled tensor is allocated on ``self.device`` for every
        layer of the (single) KV cache group. The tensors are passed to
        `bind_kv_cache` together with the static forward context and
        ``self.kv_caches``. Under SPMD, each cache in ``self.kv_caches`` is
        sharded along dimension 1 over the mesh axis ``'x'``.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: If more than one KV cache group is
                configured, or a layer's KV cache spec is not an
                `AttentionSpec`.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        # Recreate the input batch with the KV cache's block size when it
        # differs from `self.block_size`.
        if kv_cache_config.kv_cache_groups[
                0].kv_cache_spec.block_size != self.block_size:
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=[
                    kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
                ],
            )
        # Verify dtype compatibility between block_table_cpu and input_batch
        assert self.block_table_cpu.dtype == self.input_batch.block_table[
            0].get_cpu_tensor().dtype

        # Map each layer name to the byte size of its dedicated KV tensor.
        kv_cache_sizes = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            assert len(kv_cache_tensor.shared_by) == 1, (
                "KV cache tensor shared by multiple layers is not supported in "
                "TPU.")
            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

        kv_caches: dict[str, torch.Tensor] = {}
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_size = kv_cache_sizes[layer_name]
                assert tensor_size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_size // kv_cache_spec.page_size_bytes  # noqa
                if isinstance(kv_cache_spec, AttentionSpec):
                    if self.use_spmd:
                        num_kv_heads = kv_cache_spec.num_kv_heads
                        assert self.original_parallel_config is not None
                        tp_size = \
                            self.original_parallel_config.tensor_parallel_size
                        # TODO: Handle kv cache duplication under SPMD mode.
                        assert num_kv_heads % tp_size == 0, (
                            f"num_kv_heads {num_kv_heads} must be divisible by "
                            f"tp_size {tp_size} under SPMD mode")
                    kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype

                    tpu_kv_cache = torch.zeros(kv_cache_shape,
                                               dtype=dtype).to(self.device)

                    kv_caches[layer_name] = tpu_kv_cache
                else:
                    raise NotImplementedError(
                        "Unsupported KV cache spec type: "
                        f"{type(kv_cache_spec).__name__}")

        # Setup `kv_cache_config` and `kv_caches` for models
        # with cross-layer KV sharing
        if self.shared_kv_cache_layers:
            initialize_kv_cache_for_kv_sharing(
                self.shared_kv_cache_layers,
                kv_cache_config.kv_cache_groups,
                kv_caches,
            )

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

        if self.use_spmd:
            # Shard KV Cache
            for cache in self.kv_caches:
                xs.mark_sharding(cache, self.mesh, (None, 'x', None, None))

1429
1430
    def reset_dynamo_cache(self):
        """Drop Dynamo's cached compiled code for the (language) model.

        For multimodal models the language model's ``.model`` is inspected;
        otherwise ``self.model.model`` is. The cache is cleared only when
        that object is a `TorchCompileWrapperWithCustomDispatcher`; otherwise
        this is a no-op.
        """
        if self.is_multimodal_model:
            compiled_model = self.model.get_language_model().model
        else:
            compiled_model = self.model.model
        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
            logger.info("Clear dynamo cache and cached dynamo bytecode.")
            torch._dynamo.eval_frame.remove_from_cache(
                compiled_model.original_code_object)
            compiled_model.compiled_codes.clear()
1439

1440
1441
1442
1443
1444
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def select_hidden_states(self, hidden_states, indices_do_sample):
        """Index `hidden_states` along its first dimension.

        Compiled on its own with the openxla backend and static shapes.
        """
        selected = hidden_states[indices_do_sample]
        return selected

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def compute_logits(self,
                       sample_hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute logits for `sample_hidden_states`.

        Delegates to ``self.model.compute_logits``, passing ``None`` as its
        second argument. Compiled on its own with the openxla backend.
        """
        return self.model.compute_logits(sample_hidden_states, None)

1449
1450
1451
    # TODO: Under SPMD mode, sample_from_logits has correctness issue.
    #       Re-enable the torch.compile once the issue is fixed in torchxla.
    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def sample_from_logits(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
        """
        Sample with xla-friendly function. This function is to be traced
        separately from `forward` for lighter compilation overhead.

        When ``sampling_metadata.all_greedy`` is set, the result is an argmax
        over the last dimension with ``keepdim=True``, so it keeps a trailing
        size-1 axis. Otherwise ``self.sampler`` is called and its
        ``sampled_token_ids`` are returned.
        """
        if sampling_metadata.all_greedy:
            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            out_tokens = self.sampler(logits,
                                      sampling_metadata).sampled_token_ids
        return out_tokens

1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def gather_logprobs(self, logits: torch.Tensor,
                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
        """
        Collect the top logprobs, with their token ids, for the sampled
        tokens.

        A fixed count (``self.model_config.max_logprobs``) is always gathered,
        so a single graph is pre-compiled instead of one per requested count.
        The number each request actually asked for is selected later on CPU.
        """
        all_logprobs = self.sampler.compute_logprobs(logits)
        # The trailing axis of `sampled_tokens` is squeezed away (presumably
        # size 1, as in the (num_reqs, 1) precompile dummy) before it is used
        # as `token_ids`.
        flat_token_ids = sampled_tokens.squeeze(-1)
        return self.sampler.gather_logprobs(all_logprobs,
                                            self.model_config.max_logprobs,
                                            token_ids=flat_token_ids)

1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def structured_decode(self, require_struct_decoding: torch.Tensor,
                          grammar_bitmask: torch.Tensor, logits: torch.Tensor,
                          arange: torch.Tensor) -> torch.Tensor:
        """Apply grammar masks only where structured output is required.

        The bitmask is applied to every row. `torch.where` then keeps the
        masked values where `require_struct_decoding` is true and the original
        `logits` elsewhere.
        """
        masked_logits = self.apply_grammar_bitmask(logits, grammar_bitmask,
                                                   arange)
        return torch.where(require_struct_decoding, masked_logits, logits)

    def apply_grammar_bitmask(self, logits: torch.Tensor,
                              grammar_bitmask: torch.Tensor,
                              arange: torch.Tensor):
        """Return a copy of `logits` with grammar-disallowed tokens at -inf.

        Each row of `grammar_bitmask` packs one bit per token into integer
        words. Bit ``j`` of word ``w`` refers to token ``w * len(arange) + j``,
        and a cleared bit marks the token as disallowed. Bits past
        ``self.vocab_size`` are ignored.

        All rows are unpacked at once instead of looping over requests, so the
        traced graph no longer contains one unrolled sub-graph per request.

        Args:
            logits: Logits with one row per request; assumed
                ``(num_reqs, self.vocab_size)``. Not modified.
            grammar_bitmask: ``(num_reqs, num_words)`` packed bitmask.
            arange: Bit positions used to unpack each word.
        """
        assert (logits.shape[0] == grammar_bitmask.shape[0])
        # (num_reqs, num_words, 1) >> (1, 1, bits) -> (num_reqs, num_words,
        # bits); a 0 bit means the token is not allowed.
        disallowed = (torch.bitwise_right_shift(
            grammar_bitmask[:, :, None], arange[None, None, :]) & 1) == 0
        # Lay each row's words out end to end, then drop the bits that lie
        # beyond the vocabulary. flatten() also copes with an empty batch.
        disallowed = disallowed.flatten(start_dim=1)[:, :self.vocab_size]
        return logits.masked_fill(disallowed, -float("inf"))

1502
1503
    def get_multimodal_embeddings(self, *args, **kwargs):
        """Forward all arguments to ``self.model.get_multimodal_embeddings``."""
        embed_fn = self.model.get_multimodal_embeddings
        return embed_fn(*args, **kwargs)
1504

1505
1506
1507
    def get_input_embeddings(self, *args, **kwargs):
        """Forward all arguments to ``self.model.get_input_embeddings``."""
        embed_fn = self.model.get_input_embeddings
        return embed_fn(*args, **kwargs)

1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
    def prepare_structured_decoding_input(
        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the inputs of `structured_decode` for the current batch.

        Args:
            logits: Logits of the current batch. Only its first dimension
                (request count) and its device are used.
            scheduler_output: Must carry a non-None ``grammar_bitmask`` whose
                rows are addressed through ``structured_output_request_ids``.

        Returns:
            ``(require_structured_out, grammar_bitmask, arange)``, all moved
            to ``logits.device``:
            - per-request flags marking which requests need structured
              decoding;
            - the bitmask rows reordered into this runner's batch order, with
              zeros for requests without a grammar;
            - ``self.structured_decode_arange``.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        assert grammar_bitmask is not None
        num_reqs, _ = logits.shape

        # Reset pre-allocated tensors
        self.grammar_bitmask_cpu.zero_()
        self.require_structured_out_cpu.zero_()

        # We receive the structured output bitmask from the scheduler, but the
        # indices of the requests in the batch may not match the indices of
        # the bitmask since the scheduler doesn't know how the tpu runner is
        # ordering the requests in the batch. We need to match the order of
        # bitmask with the order of requests
        struct_out_indices: list[int] = []
        mask_indices: list[int] = []
        for req_id in self.input_batch.req_ids:
            mask_index = scheduler_output.structured_output_request_ids.get(
                req_id)
            if mask_index is None:
                continue
            batch_index = self.input_batch.req_id_to_index[req_id]
            struct_out_indices.append(batch_index)
            mask_indices.append(mask_index)

        # Nothing to scatter when no request in this batch has a grammar;
        # the zeroed buffers above already describe that case.
        if struct_out_indices:
            self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
                grammar_bitmask[mask_indices])
            # It's not guaranteed that all requests in this batch require
            # structured output, so keep a bool flag per request that needs
            # it. A separate name keeps `struct_out_indices` a list[int].
            struct_out_index_tensor = torch.tensor(struct_out_indices,
                                                   dtype=torch.long)
            self.require_structured_out_cpu[struct_out_index_tensor] = True

        return (self.require_structured_out_cpu[:num_reqs].to(logits.device),
                self.grammar_bitmask_cpu[:num_reqs].to(logits.device),
                self.structured_decode_arange.to(logits.device))

1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
    def _get_mm_dummy_batch(self, modality: str,
                            batch_size: int) -> BatchedTensorInputs:
        """Build a batch of ``batch_size`` identical dummy inputs of
        ``modality``.

        The first item of ``modality`` from the registry's decoder dummy data
        is replicated ``batch_size`` times, batched with
        ``MultiModalKwargs.batch``, and converted with
        ``MultiModalKwargs.as_kwargs(..., device=self.device)``.
        """
        # Dummy data for pre-compiling multimodal models.
        dummy_request_data = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
        )
        dummy_mm_data = dummy_request_data.multi_modal_data

        # Dummy data definition in V0 may contain multiple multimodal items
        # (e.g, multiple images) for a single request, therefore here we
        # always replicate first item by max_num_mm_items times since in V1
        # they are scheduled to be processed separately.
        assert isinstance(dummy_mm_data, MultiModalKwargs), (
            "Expected dummy multimodal data to be of type "
            f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. "
            "This is most likely due to the model not having a merged "
            "processor.")

        # When models have a merged processor, their dummy data is
        # already batched `MultiModalKwargs`, therefore we take the first
        # `MultiModalKwargsItem` from the desired modality to profile on.
        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                         batch_size)
        return MultiModalKwargs.as_kwargs(
            batched_dummy_mm_inputs,
            device=self.device,
        )
1576

1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588

def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
    """Return the padded request counts to pre-compile for.

    Starts at ``max(MIN_NUM_SEQS, min_req_size)`` and repeatedly advances to
    ``_get_padded_num_reqs_with_upper_limit(current + 1, max_req_size)``
    until the value stops changing. The result is empty when the starting
    value already exceeds ``max_req_size``.
    """
    logger.info("Preparing request paddings:")
    # min_req_size must be a positive power of two.
    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
    result: list[int] = []
    current = max(MIN_NUM_SEQS, min_req_size)
    while current <= max_req_size:
        # The helper saturates at max_req_size; stop once it repeats.
        if result and result[-1] == current:
            break
        result.append(current)
        logger.info("    %d", current)
        current = _get_padded_num_reqs_with_upper_limit(
            current + 1, max_req_size)
    return result
1589
1590


1591
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Round ``x`` up to a padded request count, capped at ``upper_limit``.

    Values ``<= MIN_NUM_SEQS`` map to ``MIN_NUM_SEQS``. Larger values map to
    the smallest power of two ``>= x``. The result never exceeds
    ``upper_limit``.
    """
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)
1594
1595


1596
1597
def _get_token_paddings(min_token_size: int, max_token_size: int,
                        padding_gap: int) -> list[int]:
    """Generate a list of padding sizes, starting from min_token_size and
    ending with a number that can cover max_token_size.

    If padding_gap == 0:
        the size doubles each step (exponential).
    Otherwise:
        the size doubles while it is <= padding_gap, then grows linearly by
        padding_gap until it reaches or exceeds max_token_size.

    NOTE(review): in the incremental mode, when min_token_size > padding_gap
    the linear phase starts from min_token_size // 2. The first entry can
    therefore be smaller than min_token_size.
    """
    # assert min_token_size is power of 2
    assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
    paddings = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential token paddings:")
        while True:
            logger.info("    %d", num)
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        logger.info("Using incremental token paddings:")
        # Exponential phase: double while still within padding_gap.
        while num <= padding_gap:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2
        # Step back to the last value appended (or min_token_size // 2 when
        # nothing was appended) before starting the linear phase.
        num //= 2
        # Linear phase: grow by padding_gap until max_token_size is covered.
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
    """Return the first element in paddings list greater or equal to x.
    """
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669


def replace_set_lora(model) -> None:
    """Patch every `BaseLayerWithLoRA` in ``model`` so that ``set_lora`` and
    ``reset_lora`` are each followed by ``xm.mark_step()``.

    The original bound methods are kept on each module as
    ``_original_set_lora`` / ``_original_reset_lora``. The replacements are
    bound to the module instance with ``__get__``.
    """

    def _tpu_set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ) -> None:
        # TODO: The integer index leads to a recompilation, but converting it
        # to a tensor doesn't seem to work anymore. This might be fixed with a
        # later release of torch_xla.
        self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
        # NOTE(review): presumably cuts the lazy XLA graph right after the
        # LoRA weight update so it is executed on its own; confirm.
        xm.mark_step()

    def _tpu_reset_lora(self, index: int) -> None:
        # Same wrapping as _tpu_set_lora, for the reset path.
        self._original_reset_lora(index)
        xm.mark_step()

    # Patch each LoRA-wrapped module in place, keeping the originals so the
    # wrappers can delegate to them.
    for _, module in model.named_modules():
        if isinstance(module, BaseLayerWithLoRA):
            module._original_set_lora = module.set_lora
            module._original_reset_lora = module.reset_lora
            module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
            module.reset_lora = _tpu_reset_lora.__get__(
                module, module.__class__)