tpu_model_runner.py 46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import bisect
3
import time
4
from typing import TYPE_CHECKING, Optional, cast
5
6
7
8
9
10
11
12
13
14
from unittest.mock import patch

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

15
import vllm.envs as envs
16
17
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
18
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
19
from vllm.config import VllmConfig
20
from vllm.forward_context import set_forward_context
21
22
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
23
24
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
25
from vllm.multimodal.utils import group_mm_inputs_by_modality
26
from vllm.sampling_params import SamplingType
27
from vllm.sequence import IntermediateTensors
28
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
29
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
30
                                               PallasMetadata)
31
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
32
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
33
                                        KVCacheSpec, SlidingWindowSpec)
34
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
35
36
37
                             ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
38
39
40
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

41
42
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
43

44
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

# Slot id written into the padded tail of the slot mapping (see
# `_prepare_inputs`). Here we utilize the behavior that out-of-bound index is
# ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1

# Smallest output size: lower bound applied to `max_num_seqs` when sizing the
# persistent batch.
MIN_NUM_SEQS = 8


class TPUModelRunner:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Create the runner: cache configs, the persistent batch and the
        reusable CPU staging buffers.

        Args:
            vllm_config: Aggregated engine configuration. Its sub-configs are
                stored as attributes for convenient access.
            device: Device that model inputs are copied to on every step.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION

        self.enforce_eager = model_config.enforce_eager

        # Snapshot the number of compiled XLA graphs so that later
        # recompilations can be detected. `_update_num_xla_graphs` is a no-op
        # unless recompilation checking is enabled and eager mode is off.
        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        self._hidden_states_dtype = self.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.input_ids_np = self.input_ids_cpu.numpy()

        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int64,
                                            device="cpu")
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.block_table_cpu = torch.zeros(
            (self.max_num_tokens, self.max_num_blocks_per_req),
            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)

        # Token-count buckets that `_prepare_inputs` pads the number of
        # scheduled tokens up to, so input shapes do not vary every step.
        self.num_tokens_paddings = _get_paddings(
            min_token_size=16,
            max_token_size=self.max_num_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)

    def _update_num_xla_graphs(self, case_str):
        """Account for XLA graphs compiled since the previous snapshot.

        Does nothing unless recompilation checking is enabled and the model
        is not running in eager mode. When new graphs are found they are
        logged and added to ``self.num_xla_graphs``.

        Args:
            case_str: Label describing what triggered the compilation; only
                used in the log message.
        """
        if not self.check_recompilation or self.enforce_eager:
            return

        cached_now = xr.get_num_cached_compilation_graph()
        delta = cached_now - self.num_xla_graphs
        if delta != 0:
            logger.info("Add new %d compiled XLA graphs due to %s", delta,
                        case_str)
            self.num_xla_graphs += delta

    def _verify_num_xla_graphs(self, case_str):
        """Fail if any XLA graph was compiled since the last recorded count.

        Does nothing unless recompilation checking is enabled and the model
        is not running in eager mode.

        Args:
            case_str: Label identifying the step being checked; included in
                the error message.

        Raises:
            AssertionError: if the number of cached compilation graphs differs
                from ``self.num_xla_graphs``. Raised explicitly rather than via
                an ``assert`` statement so this opt-in check still runs when
                Python is started with ``-O``.
        """
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        if self.num_xla_graphs != curr_cached_graph:
            raise AssertionError(
                "Recompilation after warm up is detected during {}."
                " num_xla_graphs = {} curr_cached_graph = {}".format(
                    case_str, self.num_xla_graphs, curr_cached_graph))

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request, i.e. if any
            row of the persistent batch was removed or added in this call.
            If False, we can skip copying SamplingMetadata to the device.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)

        # Decide whether the batch changed *before* `removed_req_indices` is
        # consumed below. It holds the rows freed by both finished and
        # unscheduled requests, so finished requests count as a change too,
        # as the docstring promises.
        batch_changed = bool(removed_req_indices) or bool(req_ids_to_add)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        return batch_changed

    def get_model(self) -> nn.Module:
        assert self.model is not None
        return self.model

342
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Decoder layers get a ``SlidingWindowSpec`` when the module has a
        sliding window and a ``FullAttentionSpec`` otherwise; encoder and
        encoder-only layers are skipped.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: for ``AttentionType.ENCODER_DECODER`` layers.
            ValueError: for any other unrecognized attention type.
        """
        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                if attn_module.sliding_window is not None:
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=attn_module.dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=False,
                    )
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=attn_module.dtype,
                        use_mla=False,
                    )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        """Build padded model inputs and attention metadata on the device.

        Fills the persistent CPU buffers (token ids, positions, slot mapping,
        query start locations, sequence lengths) for the requests currently in
        the persistent batch. The token dimension is padded to a bucket from
        ``self.num_tokens_paddings`` and the results are copied to
        ``self.device``. As a side effect, ``self.input_ids`` and
        ``self.position_ids`` are set to the padded device tensors.

        Returns:
            ``(attn_metadata, logits_indices)``: the ``PallasMetadata`` for
            this step, and per (padded) request the flat index of its last
            scheduled token; padded requests map to index 0.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Get the number of scheduled tokens for each request.
        num_scheduled_tokens_per_req = []
        max_num_scheduled_tokens_all_reqs = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens_per_req.append(num_tokens)
            max_num_scheduled_tokens_all_reqs = max(
                max_num_scheduled_tokens_all_reqs, num_tokens)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        assert max_num_scheduled_tokens_all_reqs > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, what are the corresponding req index.
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, what is its position in corresponding req.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # Look up each token's block number by indexing the flattened CPU
        # block table of the persistent batch with the indices above.
        batch_block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = batch_block_table_cpu.flatten()[
            block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_np[1:num_reqs + 1])
        # Entries past the real requests are set to 1 so that the padded
        # `logits_indices` computed below (query_start_loc - 1) evaluate to 0.
        self.query_start_loc_np[num_reqs + 1:] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        padded_num_tokens = _get_padded_token_len(self.num_tokens_paddings,
                                                  total_num_scheduled_tokens)
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[total_num_scheduled_tokens:padded_num_tokens] = 0
        self.input_ids = self.input_ids_cpu[:padded_num_tokens].to(
            self.device)
        self.position_ids = self.positions_cpu[:padded_num_tokens].to(
            self.device)
        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = self.slot_mapping_cpu[:padded_num_tokens].to(
            self.device)
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs],
                                  dtype=torch.int32,
                                  device=self.device),
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)
        return attn_metadata, logits_indices

    def _scatter_placeholders(
        self,
        embeds: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if is_embed is None:
            return embeds

        placeholders = embeds.new_full(
            (is_embed.shape[0], embeds.shape[-1]),
            fill_value=torch.nan,
        )
        placeholders[is_embed] = embeds
        return placeholders

    def _gather_placeholders(
        self,
        placeholders: torch.Tensor,
        is_embed: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if is_embed is None:
            return placeholders

        return placeholders[is_embed]

    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on the scheduled encoder inputs and cache
        the results in ``self.encoder_cache``.

        Each output is stored under ``encoder_cache[req_id][input_id]`` after
        being expanded with ``scatter_mm_placeholders`` using the input's
        ``is_embed`` mask. Does nothing when no encoder inputs are scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            # Iterating a tensor yields its first-dim slices, so this works for
            # both output forms described above.
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs, matching them back to the requests in the
        # same order the inputs were collected.
        for (req_id, input_id, pos_info), output in zip(req_ids_pos,
                                                        encoder_outputs):
            self.encoder_cache.setdefault(req_id, {})[input_id] = (
                scatter_mm_placeholders(output, is_embed=pos_info.is_embed))

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
        """Collect the cached encoder embeddings needed by this step.

        For each request in the persistent batch, every multimodal item whose
        placeholder range overlaps the tokens scheduled in this step
        contributes the overlapping slice of its cached encoder output, with
        non-embedding rows dropped via ``gather_mm_placeholders``.

        Returns:
            The embedding slices in batch order, and in multimodal-item order
            within each request.
        """
        mm_embeds: list[torch.Tensor] = []
        # Only the first `num_reqs` entries are iterated, consistent with
        # `_prepare_inputs`.
        num_reqs = self.input_batch.num_reqs
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds

    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> ModelRunnerOutput:
        """Run one model step (forward pass + sampling) for the scheduled batch.

        Args:
            scheduler_output: Scheduler decisions for this step; provides the
                number of scheduled tokens for every running request.
            intermediate_tensors: Accepted for interface compatibility; not
                used by this runner.

        Returns:
            A ``ModelRunnerOutput`` with the sampled token ids of every running
            request (an empty list for partial-prefill requests whose sampled
            token is discarded), or ``EMPTY_MODEL_RUNNER_OUTPUT`` when no
            tokens were scheduled.
        """
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        # Prepare inputs
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    self.input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
            input_ids = None
        else:
            # For text-only models, feed token ids directly and let the model
            # perform the embedding lookup itself.
            input_ids = self.input_ids
            inputs_embeds = None
        num_reqs = self.input_batch.num_reqs

        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
        # are copied to device in chunks of pre-compiled padded shape to
        # avoid recompilations.
        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
            from_input_batch(self.input_batch, logits_indices)

        # Run the decoder
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=self.position_ids,
                inputs_embeds=inputs_embeds,
            )
        selected_token_ids = self.sample_from_hidden(hidden_states,
                                                     tpu_sampling_metadata)
        # Remove padding on cpu and keep dynamic op outside of xla graph.
        # Everything below is host-side bookkeeping on this CPU copy.
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]

        # Split requests into those whose sampled token is kept (with their
        # new sequence length) and partial requests whose token is dropped.
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices: list[int] = []
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None, "req_ids contains None"
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                # NOTE(review): the -4 offset relies on a CUDA-specific torch
                # implementation detail; confirm it is meaningful on TPU.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)

                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        # Prompt logprobs are not computed here; every request maps to None.
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {
            req_id: None
            for req_id in req_ids
        }

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1
        else:
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            # Same rule as the single-token path: tokens sampled for partial
            # requests are dropped and must not advance their token count.
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()
                gen_lens[i] = 0
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

        return model_runner_output

    def load_model(self) -> None:
        """Load the model onto the TPU and create the sampler.

        Sets ``self.device`` from the device config, loads the weights with
        the embedding TP rank taken from the XLA runtime, waits for all queued
        device work to finish, and then stores ``self.model`` and
        ``self.sampler``.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): The executor assigns TP ranks to workers, but the xm
        # runtime may assign different ranks internally, so the gloo (cpu)
        # and xm (tpu) rank assignments can disagree. All-reduce in linear
        # layers is rank-agnostic, but all-gather concatenates outputs in
        # rank order. As a workaround, only the embedding weights are loaded
        # with the xm rank assignment.
        rank_lookup_target = (
            "vllm.model_executor.layers.vocab_parallel_embedding."
            "get_tensor_model_parallel_rank")
        with patch(rank_lookup_target, return_value=xr.global_ordinal()):
            loaded_model = get_model(vllm_config=self.vllm_config)

        # Flush every XLA op queued during initialization and weight loading
        # before the model is used.
        xm.mark_step()
        xm.wait_device_ops()

        self.model = loaded_model
        self.sampler = TPUSampler()

    @torch.no_grad()
    def _dummy_run(self, num_tokens: int) -> None:
        """Trace one forward pass on zero-filled inputs of ``num_tokens`` rows.

        The attention metadata describes ``self.max_num_reqs`` requests with
        one query token each, of which ``min(num_tokens, self.max_num_reqs)``
        are reported via ``num_seqs``. The dtype of the model output is stored
        in ``self._hidden_states_dtype``.
        """
        device = self.device
        max_reqs = self.max_num_reqs
        num_active_reqs = min(num_tokens, max_reqs)

        if self.is_multimodal_model:
            # Multimodal models are fed embeddings instead of token ids.
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
                                        device=device)
        else:
            input_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=device)
            inputs_embeds = None

        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32,
                                   device=device)
        slot_mapping = torch.zeros(num_tokens,
                                   dtype=torch.int64,
                                   device=device)
        block_tables = torch.zeros((max_reqs, self.block_table_cpu.shape[1]),
                                   dtype=torch.int32,
                                   device=device)
        # Every dummy request has exactly one query token, so the cumulative
        # start offsets are simply 0, 1, ..., max_reqs.
        query_start_loc = torch.arange(max_reqs + 1,
                                       dtype=torch.int32).to(device)
        context_lens = torch.ones((max_reqs, ),
                                  dtype=torch.int32,
                                  device=device)
        num_seqs = torch.tensor([num_active_reqs],
                                dtype=torch.int32,
                                device=device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
        )

        # Mark the token dimension (dim 0) as dynamic for dynamo.
        if self.is_multimodal_model:
            token_major_input = inputs_embeds
        else:
            token_major_input = input_ids
        torch._dynamo.mark_dynamic(token_major_input, 0)
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             inputs_embeds=inputs_embeds)
        self._hidden_states_dtype = out.dtype

    def capture_model(self) -> None:
        """Compile the model and the sampler for all padded input shapes.

        First runs ``_dummy_run`` for every size in
        ``self.num_tokens_paddings``. Then it traces ``sample_from_hidden``
        for every (num_tokens, num_reqs) bucket. After each phase the XLA
        graph count is recorded with ``_update_num_xla_graphs``.
        """
        logger.info("Compiling the model with different input shapes.")

        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(num_tokens)
            xm.mark_step()
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model")

        logger.info("Compiling sampling with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        device = self.device
        # The smallest sampling bucket must be capped at `max_num_reqs` just
        # like every later bucket below. Otherwise, when
        # max_num_reqs < MIN_NUM_SEQS, the capped size used at runtime would
        # never be compiled here.
        first_num_reqs = _get_padded_num_reqs_with_upper_limit(
            MIN_NUM_SEQS, self.max_num_reqs)
        # Compile sampling step for different model+sampler outputs in bucketed
        # n_tokens x max_num_reqs. Graph is really small so this is fine.
        for num_tokens in self.num_tokens_paddings:
            num_reqs_to_sample = first_num_reqs
            # `_hidden_states_dtype` was recorded by `_dummy_run` above.
            dummy_hidden = torch.randn((num_tokens, hsize),
                                       device=device,
                                       dtype=self._hidden_states_dtype)
            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
            while True:
                indices = torch.zeros(
                    num_reqs_to_sample,
                    dtype=torch.int32,
                    device=device,
                )
                xm.mark_step()
                sampling_meta = TPUSupportedSamplingMetadata.\
                    from_input_batch(self.input_batch, indices)
                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
                            num_reqs_to_sample)
                # Pull the result to the host so the traced graph runs (and
                # is compiled) now rather than at serving time.
                self.sample_from_hidden(dummy_hidden, sampling_meta).cpu()
                # Requests can't be more than tokens. But do compile for the
                # next bigger value in case num_tokens uses bucketed padding.
                if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
                    break
                # Make sure to compile the `max_num_reqs` upper-limit case
                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
                    num_reqs_to_sample + 1, self.max_num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sampling")

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        For every layer, allocates a zero-filled device tensor whose shape
        comes from ``PallasAttentionBackend.get_kv_cache_shape``. The tensors
        are then passed to ``bind_kv_cache`` together with the static forward
        context and ``self.kv_caches``.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: If more than one KV cache group is configured,
                or if a layer uses a KV cache spec other than
                ``FullAttentionSpec``.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                # The per-layer byte budget must hold a whole number of pages.
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                if not isinstance(kv_cache_spec, FullAttentionSpec):
                    raise NotImplementedError(
                        f"KV cache spec {type(kv_cache_spec).__name__} "
                        f"(layer {layer_name!r}) is not supported by the TPU "
                        "model runner.")
                kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                    dtype=kv_cache_spec.dtype,
                                                    device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def reset_dynamo_cache(self):
        """Drop dynamo's cached compilation for the compiled backbone module.

        The module is ``self.model.language_model.model`` for multimodal models
        and ``self.model.model`` otherwise. Nothing happens unless that module
        is a ``TorchCompileWrapperWithCustomDispatcher``.
        """
        if not self.is_multimodal_model:
            compiled_model = self.model.model
        else:
            assert hasattr(self.model, "language_model")
            compiled_model = self.model.language_model.model

        is_wrapped = isinstance(compiled_model,
                                TorchCompileWrapperWithCustomDispatcher)
        if not is_wrapped:
            return

        logger.info("Clear dynamo cache and cached dynamo bytecode.")
        torch._dynamo.eval_frame.remove_from_cache(
            compiled_model.original_code_object)
        compiled_model.compiled_codes.clear()

    def sample_from_hidden(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """Sample next tokens with an XLA-friendly function.

        This is traced separately from the model forward pass to keep
        compilation overhead low.

        Args:
            hidden_states: Model output; rows are selected with
                ``sampling_metadata.indices_do_sample``.
            sampling_metadata: Padded TPU sampling parameters.

        Returns:
            ``argmax`` token ids (``keepdim=True``) when
            ``sampling_metadata.all_greedy`` holds, otherwise the sampler's
            ``sampled_token_ids``.
        """
        # Select the rows to sample from; this tensor has a fixed,
        # pre-compiled size.
        rows_to_sample = hidden_states[sampling_metadata.indices_do_sample]
        # No SamplingMetadata is passed, so pruning in LogitsProcessor is off.
        logits = self.model.compute_logits(rows_to_sample, None)

        greedy_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        sampler_out = self.sampler(logits, sampling_metadata)
        # Both branches are traced in one pass; `all_greedy` is a scalar, so
        # torch.where acts as an optimized if/else inside the graph.
        return torch.where(sampling_metadata.all_greedy, greedy_token_ids,
                           sampler_out.sampled_token_ids)


def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple
998
999


1000
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    """Pad a request count to its bucket size, capped at ``upper_limit``.

    Counts up to ``MIN_NUM_SEQS`` map to ``MIN_NUM_SEQS``. Larger counts are
    rounded up to the next power of two. The result never exceeds
    ``upper_limit``.
    """
    if x > MIN_NUM_SEQS:
        # Smallest power of two that is >= x.
        bucket = 1 << (x - 1).bit_length()
    else:
        bucket = MIN_NUM_SEQS
    return min(bucket, upper_limit)


def _get_paddings(min_token_size: int, max_token_size: int,
                  padding_gap: int) -> list[int]:
    """Generate a list of padding sizes, starting from min_token_size and
    ending with a number that can cover max_token_size.

    If padding_gap == 0 then:
        increase 2X each time (exponential); the last size is the first
        min_token_size * 2**k that is >= max_token_size.
    else:
        first increase the size to twice,
        then increase the padding size by padding_gap.

    Raises:
        ValueError: If min_token_size <= 0 or padding_gap < 0. Either value
            would make the generation loops below never terminate.
    """
    if min_token_size <= 0:
        raise ValueError(
            f"min_token_size must be positive, got {min_token_size}")
    if padding_gap < 0:
        raise ValueError(
            f"padding_gap must be non-negative, got {padding_gap}")

    paddings = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential paddings:")
        # Stop only once a bucket covers max_token_size. Stopping at the last
        # size <= max_token_size would leave the lengths between that size
        # and max_token_size without a bucket whenever max_token_size is not
        # min_token_size * 2**k.
        while True:
            logger.info("    %d", num)
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2

    else:
        logger.info("Using incremental paddings:")
        while num <= padding_gap:
            logger.info("    %d", num)
            paddings.append(num)
            num *= 2
        # Undo the final doubling so increments start from the last bucket.
        num //= 2
        while num < max_token_size:
            num += padding_gap
            logger.info("    %d", num)
            paddings.append(num)

    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
    """Return the first element in paddings list greater or equal to x.
    """
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings)
    return paddings[index]