model_runner.py 49.6 KB
Newer Older
1
import contextlib
2
import time
3
4
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9

10
11
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                            get_attn_backend)
12
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
13
                         SchedulerConfig, VisionLanguageConfig)
14
15
16
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
                                                   pynccl_utils)
17
from vllm.logger import init_logger
18
19
20
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
21
from vllm.model_executor import SamplingMetadata
22
from vllm.model_executor.model_loader import get_model
23
from vllm.sampling_params import SamplingParams, SamplingType
24
25
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
26
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
27
28
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
29
30
31
32

logger = init_logger(__name__)

# Placeholder written into slot_mapping for tokens that have no real KV-cache
# slot in this file: memory-profiling runs, sliding-window-masked prompt
# tokens, and CUDA-graph batch padding.
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))
40
41


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
class PreparePromptMetadata(NamedTuple):
    """Flattened model inputs built for the prefill (prompt) requests.

    ``input_tokens``, ``input_positions``, ``slot_mapping`` and
    ``lora_index_mapping`` are concatenated per-token lists across all prompt
    sequences. ``prompt_lens`` and ``subquery_lens`` hold one entry per
    prompt sequence. ``lora_prompt_mapping`` holds one entry per row that
    will be sampled (all new tokens when prompt logprobs are requested,
    otherwise one per sequence).
    """

    input_tokens: List[int]
    input_positions: List[int]
    # Per-stage attention metadata for the prefill tokens; None when empty.
    attn_metadata: Optional[AttentionMetadataPerStage]
    prompt_lens: List[int]
    # Number of new (not yet cached) tokens per prompt in this step.
    subquery_lens: List[int]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return an instance carrying no tokens and no attention metadata.

        Uses ``cls`` so that subclasses get an instance of their own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )


class PrepareDecodeMetadata(NamedTuple):
    """Flattened model inputs built for the decode requests.

    Every decode sequence contributes exactly one token. When CUDA graphs are
    used, ``input_tokens``, ``input_positions``, ``slot_mapping`` and
    ``lora_index_mapping`` additionally carry padding entries up to the
    captured batch size; ``lora_prompt_mapping`` is not padded.
    """

    input_tokens: List[int]
    input_positions: List[int]
    # Built by the same ``attn_backend.make_metadata`` call as the prefill
    # path, so it is per-stage metadata (matching PreparePromptMetadata).
    attn_metadata: Optional[AttentionMetadataPerStage]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return an instance carrying no tokens and no attention metadata.

        Uses ``cls`` so that subclasses get an instance of their own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )


class BatchType(IntEnum):
    """Describes which kinds of requests a batch is made of."""

    PREFILL = 0  # only prefill (prompt) requests
    DECODE = 1  # only decode requests
    MIXED = 2  # both prefill and decode requests


102
103
104
105
106
107
108
class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the engine configs and initialize (mostly empty) runner state.

        The model is loaded later by ``load_model``, and ``block_size`` /
        ``graph_block_tables`` are filled in by ``set_block_size``.

        Args:
            model_config: Model description. May be None (see the FIXME
                below); sliding window, CUDA-graph context length and the
                attention-backend dtype then fall back to None / 0 / None.
            parallel_config: Parallelism settings, forwarded to ``get_model``.
            scheduler_config: Scheduler settings (sequence / token limits,
                chunked-prefill flag).
            device_config: Target device; ``DeviceConfig()`` is used when
                None is passed.
            lora_config: LoRA settings, or None when LoRA is disabled.
            kv_cache_dtype: KV-cache data type string, e.g. "auto" or "fp8".
            is_driver_worker: Flag consulted by ``prepare_input_tensors``.
            vision_language_config: Extra settings for vision-language models.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        has_model_config = model_config is not None
        if has_model_config:
            self.sliding_window = model_config.get_sliding_window()
        else:
            self.sliding_window = None

        if device_config is None:
            device_config = DeviceConfig()
        self.device_config = device_config
        self.device = device_config.device

        # Populated later: load_model() sets model / lora_manager and
        # set_block_size() sets block_size after initial profiling.
        self.model = None
        self.block_size = None
        self.lora_manager = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        if has_model_config:
            self.max_context_len_to_capture = (
                model_config.max_context_len_to_capture)
        else:
            self.max_context_len_to_capture = 0
        # With CUDA graphs, input block tables must be padded to
        # max_context_len_to_capture. Building that padded table in Python
        # every step is expensive, so a numpy buffer of shape
        # (max batch size to capture, max context len to capture / block size)
        # is cached and only the live rows are copied in each iteration.
        self.graph_block_tables = None  # Set after initial profiling.

        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        backend_dtype = model_config.dtype if has_model_config else None
        self.attn_backend = get_attn_backend(backend_dtype)

153
    def load_model(self) -> None:
        """Load the model weights and attach optional LoRA / KV-scale support.

        Sets ``self.model`` and records the memory consumed while loading in
        ``self.model_memory_usage``. When ``self.lora_config`` is set, the
        model is wrapped by an ``LRUCacheWorkerLoRAManager``. When the KV
        cache is FP8 on ROCm and ``quantization_param_path`` is set, the
        model's KV-cache scaling factors are loaded from that path.

        Raises:
            AssertionError: if LoRA is requested but the model lacks
                ``supported_lora_modules``, ``embedding_modules`` or
                ``embedding_padding_modules``.
            RuntimeError: if FP8 KV-cache scaling factors are provided on
                ROCm but the model has no ``load_kv_cache_scales`` method.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                self.model_config,
                self.device_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)

        self.model_memory_usage = m.consumed_memory
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            # NOTE(review): vocab_size is not assigned in __init__;
            # presumably it is provided elsewhere on this class — confirm.
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    raise RuntimeError("Using FP8 KV cache and scaling "
                                       "factors provided but model "
                                       f"{self.model.__class__} does not "
                                       "support loading scaling factors.")
            else:
                logger.warning("Using FP8 KV cache but no scaling factors "
                               "provided. Defaulting to scaling factors of "
                               "1.0. This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            if self.kv_cache_dtype == "fp8":
                # FP8 KV cache off ROCm: the dtype *is* FP8, so the generic
                # message below would be misleading here.
                logger.warning("KV cache scaling factors provided, "
                               "but scaled FP8 KV cache is currently only "
                               "supported on ROCm. "
                               "KV cache scaling factors will not be used.")
            else:
                logger.warning("KV cache scaling factors provided, "
                               "but the KV cache data type is not FP8. "
                               "KV cache scaling factors will not be used.")

203
204
205
    def set_block_size(self, block_size: int) -> None:
        """Store the KV-cache block size and allocate the graph block table.

        ``graph_block_tables`` is a zero-filled int32 numpy array with one
        row per batch slot up to the largest captured batch size and one
        column per block needed for ``max_context_len_to_capture`` tokens.
        """
        self.block_size = block_size
        num_rows = max(_BATCH_SIZES_TO_CAPTURE)
        # Must come after block_size is stored: it divides by self.block_size.
        num_cols = self.get_max_block_per_batch()
        self.graph_block_tables = np.zeros((num_rows, num_cols),
                                           dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_context_len_to_capture + block_size - 1) // block_size
213

214
215
216
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PreparePromptMetadata:
        """Build flattened model inputs for the prefill (prompt) requests.

        Every group must be a prompt with exactly one sequence. For each group
        this appends the tokens prefilled in this step (its "subquery"), their
        positions, KV-cache slots and LoRA ids, and records the per-sequence
        lengths and block tables needed for the prefill attention metadata.

        Args:
            seq_group_metadata_list: Prefill sequence groups for this step.

        Returns:
            A ``PreparePromptMetadata``; ``PreparePromptMetadata.empty()``
            when ``seq_group_metadata_list`` is empty.

        Raises:
            RuntimeError: if chunked prefill is enabled and a group also has
                prefix-cached blocks (unsupported combination).
        """
        if not seq_group_metadata_list:
            return PreparePromptMetadata.empty()

        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        # Loop-invariant. A None scheduler_config is treated as "chunked
        # prefill disabled" in both checks below (the original guarded only
        # the first one and crashed on the second).
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (chunked_prefill_enabled
                    and not (computed_block_nums is None
                             or computed_block_nums == [])):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            computed_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
            prompt_len = prefill_end
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            elif chunked_prefill_enabled:
                if seq_group_metadata.block_tables is not None:
                    # Prefill has chunked before.
                    block_table = seq_group_metadata.block_tables[seq_id]
                    prefix_block_tables.append(block_table)
                else:
                    # The first prefill.
                    prefix_block_tables.append([])
            else:
                prefix_block_tables.append([])
                # Without prefix caching or chunked prefill, a prompt is
                # always prefilled from position 0.
                assert computed_len == 0

            # Tokens already in the KV cache before this step.
            context_lens.append(computed_len)
            # Tokens fed to the model for this sequence in this step.
            num_new_tokens = prompt_len - computed_len
            subquery_lens.append(num_new_tokens)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(computed_len, prefill_end))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * num_new_tokens
            # Must agree with _prepare_sample, which selects every new token
            # whenever prompt_logprobs is not None (including 0); testing
            # truthiness here would under-count rows for prompt_logprobs=0.
            wants_prompt_logprobs = (
                seq_group_metadata.sampling_params.prompt_logprobs
                is not None)
            lora_prompt_mapping.extend(
                [lora_id] * (num_new_tokens if wants_prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                # NOTE(review): pads prompt_len entries while num_new_tokens
                # tokens are added; these agree when computed_len == 0,
                # presumably always true during profiling — confirm.
                slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prefill_end):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        assert max_subquery_len > 0

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Pad the per-sequence prefix block tables into one 2-D tensor.
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry i is the sum of the lengths of
        # sequences 0..i-1, so entry 0 stays 0.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
        )

        return PreparePromptMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            prompt_lens=prompt_lens,
            subquery_lens=subquery_lens,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping,
        )
413
414
415
416

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PrepareDecodeMetadata:
        """Build flattened model inputs for the decode requests.

        Each group must be a decode step with ``token_chunk_size == 1``.
        Every sequence contributes its last token at position
        ``seq_len - 1``. When the batch qualifies for a captured CUDA graph,
        the token, position, slot, context-length, block-table and LoRA-index
        lists are padded up to the graph batch size.

        Args:
            seq_group_metadata_list: Decode sequence groups for this step.

        Returns:
            A ``PrepareDecodeMetadata``; ``PrepareDecodeMetadata.empty()``
            when ``seq_group_metadata_list`` is empty.
        """
        if not seq_group_metadata_list:
            return PrepareDecodeMetadata.empty()

        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        block_size = self.block_size
        sliding_window = self.sliding_window

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                last_token = seq_data.get_last_token_id()
                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_tokens.append(last_token)
                input_positions.append(position)

                if sliding_window is None:
                    context_lens.append(seq_len)
                else:
                    context_lens.append(min(seq_len, sliding_window))

                block_table = seq_group_metadata.block_tables[seq_id]
                slot_block = block_table[position // block_size]
                slot_mapping.append(slot_block * block_size +
                                    position % block_size)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if sliding_window is not None:
                    # Keep only the trailing blocks covered by the window.
                    block_table = block_table[-(sliding_window //
                                                block_size):]
                block_tables.append(block_table)

        # vLLM uses cuda graph only for decoding requests; for those,
        # batch_size == number of input tokens. See `capture_model`.
        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)

        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            num_padding = graph_batch_size - batch_size
            input_tokens.extend([0] * num_padding)
            input_positions.extend([0] * num_padding)
            slot_mapping.extend([_PAD_SLOT_ID] * num_padding)
            context_lens.extend([1] * num_padding)
            block_tables.extend([] for _ in range(num_padding))
            # lora_prompt_mapping is intentionally left unpadded.
            lora_index_mapping.extend([0] * num_padding)
            batch_size = graph_batch_size

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be padded.
            num_rows = context_lens_tensor.shape[0]
            assert num_rows == len(input_tokens)
            assert num_rows == len(input_positions)
            assert num_rows == len(slot_mapping)

            # graph_block_tables is [max batch size, max context len // block
            # size]; the slice is a view, so rows are written in place.
            # NOTE(review): only the first len(block_table) entries of a row
            # are overwritten, so later entries keep values from earlier
            # calls; presumably harmless since context_lens bounds the blocks
            # read — confirm.
            graph_rows = self.graph_block_tables[:batch_size]
            for row, block_table in enumerate(block_tables):
                if block_table:
                    graph_rows[row, :len(block_table)] = block_table
            block_tables_tensor = torch.tensor(graph_rows, device=self.device)
        else:
            block_tables_tensor = make_tensor_with_pad(
                block_tables,
                max_len=max(map(len, block_tables)),
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            prompt_lens=None,
            prompt_lens_tensor=None,
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables_tensor,
            use_cuda_graph=use_captured_graph,
        )

        return PrepareDecodeMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            slot_mapping=slot_mapping,
        )
539
540
541
542
543

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Work out which hidden-state rows to sample and how to group them.

        Three cursors advance through the batch:
          * ``token_cursor``: position in the flattened model input;
          * ``logit_cursor``: row among the selected (pruned) positions;
          * ``sample_cursor``: index among the tokens actually sampled.
        For a prompt with ``prompt_logprobs`` set, every new position is
        selected but only the last one is sampled.

        ``subquery_lens[i]`` is looked up with ``i`` being the group's index in
        the full list, so prompt groups are expected to come first, in the
        same order used to build ``subquery_lens``.

        Args:
            seq_group_metadata_list: All scheduled groups (prompts first).
            prompt_lens: Passed through to ``SamplingMetadata``.
            subquery_lens: New-token count per prompt group; required when the
                list contains prompts.

        Returns:
            The ``SamplingMetadata`` for this step.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        categorized_sample_indices = {t: [] for t in SamplingType}

        token_cursor = 0
        logit_cursor = 0
        sample_cursor = 0

        for group_idx, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if not seq_group_metadata.is_prompt:
                # Decode: one token per sequence, every one of them sampled.
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(token_cursor, token_cursor + num_seqs))
                token_cursor += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(range(logit_cursor, logit_cursor + num_seqs),
                            range(sample_cursor, sample_cursor + num_seqs)))
                logit_cursor += num_seqs
                sample_cursor += num_seqs
            else:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[group_idx]
                wants_prompt_logprobs = (sampling_params.prompt_logprobs
                                         is not None)

                if wants_prompt_logprobs:
                    # NOTE: prompt token positions do not need sample, skip
                    logit_cursor += subquery_len - 1
                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        [logit_cursor, sample_cursor])
                logit_cursor += 1
                sample_cursor += 1

                last_position = token_cursor + subquery_len - 1
                if wants_prompt_logprobs:
                    selected_token_indices.extend(
                        range(token_cursor, last_position))
                selected_token_indices.append(last_position)
                token_cursor += subquery_len

                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices_tensor = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=self.device,
            pin_memory=self.pin_memory)

        categorized_sample_tensors = {
            sampling_type: maybe_expand_dim(
                async_tensor_h2d(index_pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for sampling_type, index_pairs in
            categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {
            seq_id: data
            for group in seq_group_metadata_list
            for seq_id, data in group.seq_data.items()
        }

        return SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices_tensor,
            categorized_sample_indices=categorized_sample_tensors,
            generators=generators,
        )

638
639
640
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
641
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
642
               Set[int], LoRAMapping, torch.Tensor]:
643
        if self.is_driver_worker:
644
645
646
647
648
649
650
651
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

652
            # Prepare input tensors.
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
                prompt_lens,
                subquery_lens,
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)
674
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
675
676
                                                     prompt_lens,
                                                     subquery_lens)
677

678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
            if not self.scheduler_config.chunked_prefill_enabled:
                assert (len(prefill_reqs) and len(decode_reqs)) == 0

            num_prefills = len(prompt_lens)
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors. Note that attn_metadata is currently not
            # coalesced for simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

704
705
            if self.lora_config:
                lora_mapping = LoRAMapping(
706
                    lora_index_mapping,
707
708
709
710
711
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

712
            # Broadcast the metadata.
713
714
715
716
717
718
719
720
721
722
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

723
724
725
726
727
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
728
729
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
730
                "multi_modal_input": multi_modal_input,
731
732
733
734
735
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
736
            }
737
738
739
740
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
741
            broadcast_tensor_dict(metadata_dict, src=0)
742
743
744
745
746
747
748
749

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coelescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
750
        else:
751
            metadata_dict = broadcast_tensor_dict(src=0)
752
753
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
754
755
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
756
757
758
759
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
760
            multi_modal_input = metadata_dict.pop("multi_modal_input")
761
762
763
764
765
766
767
768
769
770
771
772
773
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Create an attention metadata.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
774
775
776
777
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
778
                selected_token_indices=selected_token_indices,
779
                categorized_sample_indices=None,
Nick Hill's avatar
Nick Hill committed
780
                generators=None,
781
782
783
                perform_sampling=False,
            )

784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
            # if it is a mixed batch, decode attn_metadata is broadcasted
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
            kv_cache_dtype=self.kv_cache_dtype,
        )

801
        return (input_tokens, input_positions, attn_metadata,
802
803
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
804

805
806
807
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, on the driver worker, sample tokens.

        Args:
            seq_group_metadata_list: Scheduled sequence groups. Only the
                driver worker reads it; other workers receive their inputs
                through the broadcast in ``prepare_input_tensors``.
            kv_caches: Per-layer KV cache tensors handed to the model.

        Returns:
            The sampler output. Returns ``None`` when
            ``sampling_metadata.perform_sampling`` is false, which is the case
            for the metadata built on non-driver workers.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping,
         multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase, so a
        # graph replay is only possible for a decode-only batch whose
        # metadata asks for it. Guard against missing decode metadata so an
        # empty/odd batch falls back to eager execution instead of crashing.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            # Graphs are keyed by batch size; this assumes the decode inputs
            # were already padded to a captured size (see
            # _get_graph_batch_size) -- NOTE(review): confirm in
            # _prepare_decode.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits. This runs on every worker, not only the driver.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a worst-case dummy prefill so peak GPU memory can be measured.

        Builds ``max_num_seqs`` prompt-only sequence groups whose lengths sum
        to ``max_num_batched_tokens``. LoRA adapters are attached when LoRA is
        enabled, and fake image inputs when a vision-language config is set.
        The groups are run through ``execute_model`` with placeholder
        (``None``) KV caches, and CUDA is synchronized before returning.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Register ``max_loras`` dummy adapters: this is the maximum number of
        # distinct LoRAs that can be in use at once, and therefore the worst
        # case for memory consumption. They are then assigned to sequences
        # round-robin.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            max_num_seqs = min(
                max_num_seqs,
                max_num_batched_tokens //
                self.vision_language_config.image_feature_size)

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly: the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return

917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
    def remove_all_loras(self) -> bool:
        """Unload every LoRA adapter held by the LoRA manager.

        Returns:
            Whatever ``lora_manager.remove_all_loras()`` reports.

        Raises:
            RuntimeError: If LoRA is disabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Make ``lora_requests`` the active adapters using ``lora_mapping``.

        Args:
            lora_requests: Adapters that must be active for the next forward
                pass. Callers in this class pass a ``set``, which may be empty
                (e.g. during CUDA graph capture).
            lora_mapping: Index/prompt mapping built by the caller (see
                ``LoRAMapping``).

        Raises:
            RuntimeError: If LoRA is disabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load ``lora_request`` through the LoRA manager.

        Returns:
            Whatever ``lora_manager.add_lora()`` reports.

        Raises:
            RuntimeError: If LoRA is disabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Unload the adapter with integer id ``lora_id``.

        Returns:
            Whatever ``lora_manager.remove_lora()`` reports.

        Raises:
            RuntimeError: If LoRA is disabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the ids of adapters currently held by the LoRA manager.

        Raises:
            RuntimeError: If LoRA is disabled (no LoRA manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

943
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        Args:
            kv_caches: Per-layer KV cache tensors; every captured graph is
                recorded against these exact tensors.

        Side effects:
            Fills ``self.graph_runners`` (keyed by batch size), updates
            ``self.graph_memory_pool`` and keeps a reference to the pynccl
            backend in ``self.pynccl_backend``.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes;
        # each capture uses a leading slice, so the graphs share storage.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        # Every slot is marked as padding (_PAD_SLOT_ID).
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long).cuda()
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Only capture the sizes a padded decode batch can actually reach.
        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                decode_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )
                attn_metadata = AttentionMetadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    prefill_metadata=None,
                    decode_metadata=decode_metadata,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    # Capture with an all-zero mapping and no active
                    # adapters (index 0 presumably means "no LoRA" --
                    # NOTE(review): confirm against LoRAMapping).
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Reuse this graph's memory pool for the next (smaller) one.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        elapsed_time = time.perf_counter() - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

Woosuk Kwon's avatar
Woosuk Kwon committed
1045
    def __del__(self) -> None:
        """Delete the CUDA graphs before dropping the pynccl communicator.

        Tolerates a partially-constructed instance (``graph_runners`` never
        set) so garbage collection does not emit an ignored AttributeError.
        """
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
1054

1055
1056
1057
1058
    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        config = self.model_config
        return config.get_vocab_size()

1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071

class CUDAGraphRunner:
    """Records a model's decode forward pass in a CUDA graph and replays it.

    ``capture`` runs the model once eagerly, then records it into
    ``self.graph``. The tensors used during capture are kept as static
    buffers. ``forward`` copies fresh inputs into those buffers, replays the
    graph and returns the static output buffer.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Recorded by ``capture``; ``None`` until then.
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        # Tensors the captured graph reads from / writes to. ``forward``
        # copies into the input buffers in place before each replay.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Record the model's forward pass into ``self.graph``.

        Args:
            input_ids: Static token-id buffer used by the graph.
            positions: Static position buffer used by the graph.
            kv_caches: Per-layer KV cache tensors, reused as-is on replay.
            attn_metadata: Decode metadata; its ``slot_mapping``,
                ``decode_metadata.context_lens`` and
                ``decode_metadata.block_tables`` become static buffers.
            memory_pool: Passed as ``pool`` to ``torch.cuda.graph`` so
                several graphs can share memory.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                    **kwargs,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": attn_metadata.decode_metadata.context_lens,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy inputs into the static buffers and replay the graph.

        ``kwargs`` are accepted for call compatibility with the model but
        are not used. The returned tensor is the shared static output
        buffer, so the next replay overwrites it.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(
            attn_metadata.decode_metadata.context_lens, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

1147

1148
@contextlib.contextmanager
def _maybe_pynccl():
    """Route all-reduce through pynccl for the enclosed block when needed.

    pynccl is used only when it has been initialized and the custom
    all-reduce kernel has not. Otherwise the block runs with no extra
    context. Wraps the warm-up and capture runs in
    ``CUDAGraphRunner.capture``.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    ctx = (with_pynccl_for_all_reduce()
           if use_pynccl else contextlib.nullcontext())
    with ctx:
        yield


1158
def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    Sizes up to 2 are returned unchanged, 3 and 4 map to 4, and anything
    larger is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    Must stay in sync with ``_BATCH_SIZES_TO_CAPTURE``.
    """
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # Ceiling division by the alignment, scaled back up.
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build dummy prompt data (and an optional fake image) for profiling.

    Without a vision-language config, the prompt is ``seq_len`` zero tokens
    and no multi-modal input is produced.

    With a config, the prompt is ``image_feature_size`` copies of
    ``image_token_id`` followed by zero tokens up to ``seq_len``. There are
    no zero tokens when ``seq_len`` does not exceed ``image_feature_size``.
    A zero-filled float16 tensor of ``image_input_shape`` is returned as the
    image.

    Returns:
        ``(SequenceData, MultiModalData or None)``.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    prompt_tokens = ([vision_language_config.image_token_id] *
                     num_image_tokens + [0] * (seq_len - num_image_tokens))
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(prompt_tokens), fake_image_input