# model_runner.py
import contextlib
2
import time
3
4
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9

10
11
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                            get_attn_backend)
12
13
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
14
15
16
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
                                                   pynccl_utils)
17
from vllm.logger import init_logger
18
19
20
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
21
from vllm.model_executor import SamplingMetadata
22
from vllm.model_executor.model_loader import get_model
23
from vllm.sampling_params import SamplingParams, SamplingType
24
25
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
26
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
27
28
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
29
30
31
32

# Module logger; ``ModelRunner.load_model`` reports weight-loading time and
# KV-cache scaling-factor warnings through it.
logger = init_logger(__name__)

# Placeholder slot id written into a slot mapping for tokens that get no
# KV-cache slot: CUDA-graph padding rows, tokens masked out by the sliding
# window, and memory-profiling runs where block tables do not exist yet.
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Batch sizes for which decode CUDA graphs are captured:
# 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))


class PreparePromptMetadata(NamedTuple):
    """Host-side inputs produced by ``ModelRunner._prepare_prompt``.

    ``prepare_input_tensors`` unpacks this tuple positionally, so the field
    order below is part of the interface.
    """
    # One entry per query token emitted for this prefill step.
    input_tokens: List[int]
    input_positions: List[int]
    attn_metadata: Optional[AttentionMetadataPerStage]
    # Per sequence: index one past the last token processed in this step.
    prompt_lens: List[int]
    # Per sequence: number of tokens processed in this step.
    subquery_lens: List[int]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    # One KV-cache slot per input token (_PAD_SLOT_ID where none applies).
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance describing no prefill work.

        New containers are built on every call because callers extend the
        returned lists in place.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )


class PrepareDecodeMetadata(NamedTuple):
    """Host-side inputs produced by ``ModelRunner._prepare_decode``.

    ``prepare_input_tensors`` unpacks this tuple positionally, so the field
    order below is part of the interface.
    """
    # One entry per decode input row (CUDA-graph padding rows included).
    input_tokens: List[int]
    input_positions: List[int]
    attn_metadata: Optional[AttentionMetadata]
    lora_index_mapping: List[int]
    # One entry per real decoding sequence (never padded).
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance describing no decode work.

        New containers are built on every call, so results never share
        mutable state.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )


class BatchType(IntEnum):
    """How a batch is composed with respect to prefill and decode requests."""

    PREFILL = 0  # Every request in the batch is a prefill.
    DECODE = 1  # Every request in the batch is a decode.
    MIXED = 2  # The batch mixes prefill and decode requests.


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the configs and set up state shared by every forward pass.

        Args:
            model_config: Model configuration. It may be ``None`` in some
                tests (see the FIXME below). In that case, the sliding window,
                the CUDA-graph context length and the attention-backend dtype
                fall back to ``None`` / ``0`` / ``None``.
            parallel_config: Parallelism configuration, forwarded to
                ``get_model``.
            scheduler_config: Scheduler configuration (chunked prefill flag,
                batch limits used to size the LoRA manager).
            device_config: Target device. ``None`` selects a default
                ``DeviceConfig()``.
            load_config: Weight-loading configuration, forwarded to
                ``get_model``.
            lora_config: Enables the LoRA manager in ``load_model`` when set.
            kv_cache_dtype: KV cache dtype. ``"fp8"`` on ROCm triggers
                scaling-factor loading in ``load_model``.
            is_driver_worker: Whether this runner builds inputs itself in
                ``prepare_input_tensors``.
            vision_language_config: Required when prompts carry multi-modal
                data.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        # Set after load_model (stays None when LoRA is disabled).
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)

        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

        # Lazy initialization
        self.model: torch.nn.Module  # Set after load_model
        self.block_size: int  # Set after initial profiling.
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        # It is a numpy array (allocated by set_block_size), not a torch
        # tensor.
        self.graph_block_tables: np.ndarray  # Set after initial profiling.

    def load_model(self) -> None:
        """Load the model weights and finish model-dependent setup.

        Stores the loaded module in ``self.model``. Stores the memory reported
        by ``CudaMemoryProfiler`` in ``self.model_memory_usage``.

        If ``lora_config`` is set, the model is wrapped by an
        ``LRUCacheWorkerLoRAManager``. If the KV cache dtype is ``"fp8"`` on
        ROCm, KV-cache scaling factors are loaded from
        ``model_config.quantization_param_path`` when a path is given.

        Raises:
            AssertionError: LoRA is requested but the model lacks
                ``supported_lora_modules``, ``embedding_modules`` or
                ``embedding_padding_modules``.
            RuntimeError: FP8 scaling factors were provided but the model has
                no ``load_kv_cache_scales`` method.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
            )

        self.model_memory_usage = m.consumed_memory
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    raise RuntimeError("Using FP8 KV cache and scaling "
                                       "factors provided but model "
                                       f"{self.model.__class__} does not "
                                       "support loading scaling factors.")
            else:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning("Using FP8 KV cache but no scaling factors "
                               "provided. Defaulting to scaling factors of "
                               "1.0. This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")

    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the CUDA-graph table.

        The cached table is a zero-filled ``int32`` numpy array. It has one
        row per batch slot, up to the largest size in
        ``_BATCH_SIZES_TO_CAPTURE``. It has ``get_max_block_per_batch()``
        columns.
        """
        self.block_size = block_size
        # block_size must be set first: get_max_block_per_batch reads it.
        num_rows = max(_BATCH_SIZES_TO_CAPTURE)
        num_cols = self.get_max_block_per_batch()
        self.graph_block_tables = np.zeros((num_rows, num_cols),
                                           dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        """Return how many KV-cache blocks cover ``max_context_len_to_capture``.

        This is an integer ceiling division by ``self.block_size``, so a
        partially filled final block still counts as one block.
        """
        # -(-a // b) == ceil(a / b) for integer a and positive integer b.
        return -(-self.max_context_len_to_capture // self.block_size)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PreparePromptMetadata:
        """Collect flattened inputs and attention metadata for prefill groups.

        Every group must be a prompt with exactly one sequence. For each
        group, only the tokens still to be processed in this step are
        emitted, together with their positions, KV-cache slots and LoRA
        mappings.

        Args:
            seq_group_metadata_list: Prompt sequence groups to prepare.

        Returns:
            A ``PreparePromptMetadata``. Returns
            ``PreparePromptMetadata.empty()`` when
            ``seq_group_metadata_list`` is empty.

        Raises:
            RuntimeError: If chunked prefill is enabled and a group also has
                prefix-cached (computed) blocks.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        if len(seq_group_metadata_list) == 0:
            return PreparePromptMetadata.empty()

        # Evaluated once so both chunked-prefill checks below agree, including
        # when scheduler_config is None.
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            has_computed_blocks = (computed_block_nums is not None
                                   and len(computed_block_nums) > 0)
            if chunked_prefill_enabled and has_computed_blocks:
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            computed_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
            prompt_len = prefill_end
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            if has_computed_blocks and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            elif chunked_prefill_enabled:
                if seq_group_metadata.block_tables is not None:
                    # Prefill has chunked before.
                    block_table = seq_group_metadata.block_tables[seq_id]
                    prefix_block_tables.append(block_table)
                else:
                    # The first prefill.
                    prefix_block_tables.append([])
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert computed_len == 0

            # Tokens already in the KV cache vs. tokens processed this step.
            num_query_tokens = prompt_len - computed_len
            context_lens.append(computed_len)
            subquery_lens.append(num_query_tokens)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(computed_len, prefill_end))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * num_query_tokens
            # lora_prompt_mapping needs one entry per row that _prepare_sample
            # selects for this group. That is every query token when prompt
            # logprobs are requested, otherwise only the last one. Use the
            # same `is not None` test as _prepare_sample so that a
            # prompt_logprobs value of 0 is treated identically in both
            # places.
            wants_prompt_logprobs = (
                seq_group_metadata.sampling_params.prompt_logprobs
                is not None)
            lora_prompt_mapping.extend(
                [lora_id] *
                (num_query_tokens if wants_prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prefill_end):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        assert max_subquery_len > 0

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: element 0 stays 0, element k is the sum of
        # the first k lengths.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
        )

        return PreparePromptMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            prompt_lens=prompt_lens,
            subquery_lens=subquery_lens,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping,
        )

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PrepareDecodeMetadata:
        """Collect flattened inputs and attention metadata for decode groups.

        Each sequence contributes exactly one input row: its last token id,
        at position ``len - 1``.

        A captured CUDA graph is used when eager mode is off and both the
        batch size and the longest context fit the captured limits. In that
        case, the per-row lists are padded up to the graph batch size, and the
        block tables are written into the cached ``self.graph_block_tables``
        array (in place).

        Args:
            seq_group_metadata_list: Decode sequence groups to prepare.

        Returns:
            A ``PrepareDecodeMetadata``. Returns
            ``PrepareDecodeMetadata.empty()`` when
            ``seq_group_metadata_list`` is empty.
        """
        if len(seq_group_metadata_list) == 0:
            return PrepareDecodeMetadata.empty()

        tokens: List[int] = []
        positions: List[int] = []
        slots: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        block_size = self.block_size
        window = self.sliding_window

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            lora_id = group.lora_int_id
            if lora_id > 0:
                lora_requests.add(group.lora_request)

            for seq_id, data in group.seq_data.items():
                tokens.append(data.get_last_token_id())

                seq_len = data.get_len()
                position = seq_len - 1
                positions.append(position)

                # With a sliding window only the last `window` tokens attend.
                context_lens.append(
                    min(seq_len, window) if window is not None else seq_len)

                table = group.block_tables[seq_id]
                slots.append(table[position // block_size] * block_size +
                             position % block_size)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if window is not None:
                    # Keep only the blocks covered by the sliding window.
                    table = table[-(window // block_size):]
                block_tables.append(table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == input_tokens.
        batch_size = len(tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)

        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            num_pad = graph_batch_size - batch_size
            # Dummy rows so the inputs match the captured graph's shape.
            # lora_prompt_mapping is intentionally left unpadded.
            tokens.extend([0] * num_pad)
            positions.extend([0] * num_pad)
            slots.extend([_PAD_SLOT_ID] * num_pad)
            context_lens.extend([1] * num_pad)
            block_tables.extend([] for _ in range(num_pad))
            lora_index_mapping.extend([0] * num_pad)
            batch_size = graph_batch_size

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be padded.
            num_rows = context_lens_tensor.shape[0]
            assert num_rows == len(tokens)
            assert num_rows == len(positions)
            assert num_rows == len(slots)

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # This slice is a view, so the rows below are written into the
            # cached array itself.
            input_block_tables = self.graph_block_tables[:batch_size]
            for row, table in enumerate(block_tables):
                if table:
                    input_block_tables[row, :len(table)] = table
            block_tables_tensor = torch.tensor(input_block_tables,
                                               device=self.device)
        else:
            longest_table = max(len(table) for table in block_tables)
            block_tables_tensor = make_tensor_with_pad(
                block_tables,
                max_len=longest_table,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            prompt_lens=None,
            prompt_lens_tensor=None,
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables_tensor,
            use_cuda_graph=use_captured_graph,
        )
        return PrepareDecodeMetadata(
            input_tokens=tokens,
            input_positions=positions,
            attn_metadata=attn_metadata,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            slot_mapping=slots,
        )

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the SamplingMetadata for one model step.

        Three running offsets are tracked:
          * ``token_offset``: advances by ``subquery_lens[i]`` per prompt group
            and by the number of sequences per decode group.
            ``selected_token_indices`` is expressed in this space.
          * ``selected_offset``: position within the selected rows.
          * ``sampled_offset``: one step per sampled token.
        ``categorized_sample_indices`` pairs the last two per sampling type.

        NOTE(review): ``subquery_lens`` is indexed by a group's position in
        the full list, so this assumes every prompt group precedes every
        decode group. Confirm this ordering is guaranteed by the caller.

        A group with a seed gets a fresh ``torch.Generator`` on its prompt
        step. That generator is also collected on every later step.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        categorized_sample_indices: Dict[SamplingType,
                                         List[Tuple[int, int]]] = {
                                             sampling_type: []
                                             for sampling_type in SamplingType
                                         }
        token_offset = 0
        selected_offset = 0
        sampled_offset = 0

        for i, group in enumerate(seq_group_metadata_list):
            seq_ids = list(group.seq_data.keys())
            params = group.sampling_params
            seq_groups.append((seq_ids, params))
            bucket = categorized_sample_indices[params.sampling_type]

            if group.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]

                if params.prompt_logprobs is not None:
                    # Every prompt position is selected for logprobs, but
                    # only the last one is sampled: skip the others.
                    selected_offset += subquery_len - 1
                    selected_token_indices.extend(
                        range(token_offset, token_offset + subquery_len - 1))

                bucket.append((selected_offset, sampled_offset))
                selected_offset += 1
                sampled_offset += 1

                selected_token_indices.append(token_offset + subquery_len -
                                              1)
                token_offset += subquery_len

                if params.seed is not None:
                    group.state.generator = torch.Generator(
                        device=self.device).manual_seed(params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(token_offset, token_offset + num_seqs))
                token_offset += num_seqs

                bucket.extend(
                    zip(range(selected_offset, selected_offset + num_seqs),
                        range(sampled_offset, sampled_offset + num_seqs)))
                selected_offset += num_seqs
                sampled_offset += num_seqs

            if params.seed is not None:
                generators.append(group.state.generator)

        selected_token_indices_tensor = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=self.device,
            pin_memory=self.pin_memory)

        categorized_sample_indices_tensors = {
            sampling_type:
            maybe_expand_dim(
                async_tensor_h2d(index_pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for sampling_type, index_pairs in
            categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for group in seq_group_metadata_list:
            seq_data.update(group.seq_data)

        return SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices_tensor,
            categorized_sample_indices=categorized_sample_indices_tensors,
            generators=generators,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping], torch.Tensor]:
        """Build (driver worker) or receive (other workers) the model inputs.

        On the driver worker, the scheduled sequence groups are split into
        prefill and decode requests and each part is prepared separately.
        The token, position, slot and LoRA lists are then concatenated,
        prefill entries first and decode entries after them. The result is
        broadcast to the other workers: one broadcast for a pure prefill or
        pure decode batch, two for a mixed batch.

        Non-driver workers ignore ``seq_group_metadata_list``. They rebuild
        the same inputs from the broadcast(s).

        Args:
            seq_group_metadata_list: Scheduled sequence groups. Only read on
                the driver worker.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. ``lora_mapping`` is ``None`` when LoRA is
            disabled. On non-driver workers the sampling metadata only carries
            ``selected_token_indices`` and has ``perform_sampling=False``.
        """
        if self.is_driver_worker:
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
                prompt_lens,
                subquery_lens,
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)
            # NOTE(review): sampling metadata is built from the scheduler's
            # list order, while the flattened tensors below are laid out
            # prefill-first. This presumably relies on the scheduler already
            # placing prefill groups before decode groups; confirm.
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            # Without chunked prefill a batch is either all-prefill or
            # all-decode, never both.
            if not self.scheduler_config.chunked_prefill_enabled:
                assert not (prefill_reqs and decode_reqs)

            num_prefills = len(prompt_lens)
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors (prefill entries first, then decode entries).
            # Note that attn_metadata is currently not coalesced for
            # simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
            }
            # The first broadcast carries the prefill attention metadata when
            # there is any; otherwise it carries the decode metadata.
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
                assert decode_attn_metadata is not None
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coalescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Whatever remains in the dict is the attention metadata of the
            # first broadcast: prefill metadata for PREFILL/MIXED batches,
            # decode metadata otherwise.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type in (BatchType.PREFILL, BatchType.MIXED):
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

            # If it is a mixed batch, decode attn_metadata is broadcast
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
            kv_cache_dtype=self.kv_cache_dtype,
        )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run the model on one batch and sample the next tokens.

        Args:
            seq_group_metadata_list: Scheduled sequence groups. Only read on
                the driver worker (see ``prepare_input_tensors``).
            kv_caches: Per-layer KV cache tensors passed to the model. During
                profiling these are ``None`` placeholders.

        Returns:
            The sampler output. Returns ``None`` when this worker does not
            sample, i.e. ``sampling_metadata.perform_sampling`` is False.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase, so a
        # captured graph is only used for pure-decode batches.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            # NOTE(review): assumes the decode inputs were padded to one of
            # the captured batch sizes, so the token count is a valid key.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one dummy prefill batch used to profile memory usage.

        Builds up to ``max_num_seqs`` dummy prompt sequence groups whose
        lengths sum to ``max_num_batched_tokens``. Vision-language models get
        fewer groups; see below. When LoRA is enabled, dummy LoRAs are
        attached. The model then runs with ``None`` KV cache placeholders,
        and the call waits for CUDA to finish.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests that will
        # have unique LoRAs, and therefore the maximum memory consumption.
        # Create dummy LoRA requests (each registered with a dummy LoRA of
        # rank LORA_WARMUP_RANK). Assign them to the sequences round-robin.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the vLLM
        # block manager.
        # To exercise the worst scenario for GPU memory consumption, the
        # number of seqs (batch_size) is chosen to maximize the number of
        # images processed.
        if self.vision_language_config:
            image_feature_size = self.vision_language_config.image_feature_size
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // image_feature_size)
            if max_num_seqs < 1:
                # A single image needs more tokens than one batch allows.
                # Still profile one sequence rather than an empty batch.
                logger.warning(
                    "image_feature_size (%s) exceeds max_num_batched_tokens "
                    "(%s); profiling with a single sequence.",
                    image_feature_size, max_num_batched_tokens)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

    def remove_all_loras(self) -> bool:
        """Delegate to the LoRA manager's ``remove_all_loras``.

        Returns:
            The manager's result.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward ``lora_requests`` and ``lora_mapping`` to the LoRA manager.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the LoRA manager's ``add_lora``.

        Returns:
            The manager's result.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the LoRA manager's ``remove_lora`` for ``lora_id``.

        Returns:
            The manager's result.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Delegate to the LoRA manager's ``list_loras``.

        Returns:
            The manager's result.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph is captured for every size in ``_BATCH_SIZES_TO_CAPTURE``
        that does not exceed the padded ``max_num_seqs``, largest first. The
        runners are stored in ``self.graph_runners`` keyed by batch size.
        Each capture passes the pool of the previous one, so all graphs share
        ``self.graph_memory_pool``.

        Args:
            kv_caches: KV cache tensors used by the captured graphs. They are
                not re-copied at replay time.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # Tensors are allocated directly on the current CUDA device, which
        # avoids a host allocation followed by a copy.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   dtype=torch.long,
                                   device="cuda")
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long,
                                      device="cuda")
        # Every dummy slot is _PAD_SLOT_ID (-1). NOTE(review): presumably the
        # cache-write kernels skip such slots, so capture does not touch the
        # real KV cache; confirm.
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long,
                                  device="cuda")
        context_lens = torch.ones(max_batch_size,
                                  dtype=torch.int32,
                                  device="cuda")
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                decode_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )
                attn_metadata = AttentionMetadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    prefill_metadata=None,
                    decode_metadata=decode_metadata,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        elapsed_time = time.perf_counter() - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    def __del__(self) -> None:
        """Drop the CUDA graph runners, then the pynccl backend reference."""
        # Delete the CUDA graphs before deleting the pynccl communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        # getattr() tolerates a partially-initialized object (e.g. __init__
        # raised before graph_runners was set). Without it, __del__ would
        # raise AttributeError during garbage collection.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.pynccl_backend = None

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        model_config = self.model_config
        return model_config.get_vocab_size()


class CUDAGraphRunner:
    """Captures one decode forward pass of ``model`` as a CUDA graph.

    ``capture`` records the graph, using the given input tensors as static
    buffers. ``forward`` copies new inputs into those buffers, replays the
    graph and returns the static output buffer. The returned hidden-states
    tensor is the same object on every call.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Static tensors recorded at capture time; forward() refills the
        # input buffers and returns the output buffer.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; only valid after ``capture`` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Run the model once as a warm-up, then capture its forward pass.

        Args:
            input_ids: Token ids; becomes a static input buffer.
            positions: Token positions; becomes a static input buffer.
            kv_caches: KV cache tensors used by the graph.
            attn_metadata: Attention metadata. Its ``slot_mapping`` and the
                decode metadata's ``context_lens``/``block_tables`` become
                static input buffers.
            memory_pool: Pool handle passed to ``torch.cuda.graph`` so this
                graph may share memory with other graphs.
            **kwargs: Extra keyword arguments forwarded to the model during
                warm-up and capture.
        """
        assert self._graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                    **kwargs,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": attn_metadata.decode_metadata.context_lens,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy the inputs into the static buffers and replay the graph.

        ``kv_caches`` and ``**kwargs`` are accepted for call compatibility
        with the model and are not used.

        Returns:
            The static hidden-states buffer recorded during ``capture``.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(
            attn_metadata.decode_metadata.context_lens, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


@contextlib.contextmanager
def _maybe_pynccl():
    """Run the body under ``with_pynccl_for_all_reduce()`` when appropriate.

    pynccl is used only if it has been initialized and the custom all-reduce
    kernel has not. Otherwise the body runs with no extra context.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    context = (with_pynccl_for_all_reduce()
               if use_pynccl else contextlib.nullcontext())
    with context:
        yield


def _get_graph_batch_size(batch_size: int) -> int:
1179
1180
1181
1182
1183
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
    """
1184
1185
1186
1187
1188
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
1189
1190
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build dummy prompt data (and an image input, if any) for profiling.

    Without a vision-language config, the prompt is ``seq_len`` zero tokens
    and the image input is ``None``. With one, the prompt is
    ``image_feature_size`` copies of ``image_token_id`` followed by enough
    zero tokens to reach ``seq_len``. If ``seq_len`` is smaller than
    ``image_feature_size``, no padding is added and the prompt is longer
    than ``seq_len``. An all-zeros float16 image tensor of shape
    ``image_input_shape`` is returned alongside.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    padding = [0] * (seq_len - num_image_tokens)
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_tokens + padding), fake_image_input