model_runner.py 48.3 KB
Newer Older
1
import contextlib
2
import time
3
4
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9

10
11
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                            get_attn_backend)
12
13
14
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
15
16
17
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
                                                   pynccl_utils)
18
from vllm.logger import init_logger
19
20
21
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
22
from vllm.model_executor import SamplingMetadata
23
from vllm.model_executor.model_loader import get_model
24
from vllm.sampling_params import SamplingParams
25
26
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
27
28
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
29
30
31
32

logger = init_logger(__name__)

# Slot id written into slot_mapping for tokens whose KV entries must not be
# stored: CUDA-graph padding, tokens outside the sliding window, and the
# dummy mapping used while block tables are not yet initialized.
_PAD_SLOT_ID = -1
# NOTE(review): not referenced in the visible part of this file; presumably
# the LoRA rank used when warming up / profiling LoRA adapters -- confirm.
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for batch sizes 1, 2, 4, 8, 16, 24, 32, 40, ..., 256
# (for decode-only batches the batch size equals the number of tokens).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


class PreparePromptMetadata(NamedTuple):
    """Host-side inputs built by ``ModelRunner._prepare_prompt``.

    Token-level lists are concatenated across all prefill sequence groups
    in batch order; ``seq_lens`` and ``query_lens`` hold one entry per
    sequence group.
    """

    input_tokens: List[int]
    input_positions: List[int]
    # ``None`` only for the empty (no prefill requests) case.
    attn_metadata: Optional[AttentionMetadataPerStage]
    # Total length each sequence reaches after this step.
    seq_lens: List[int]
    # Number of tokens actually fed to the model per sequence.
    query_lens: List[int]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return metadata describing a batch with no prefill requests.

        Fresh lists/sets are created on every call, so callers may extend
        the returned containers in place.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )


class PrepareDecodeMetadata(NamedTuple):
    """Host-side inputs built by ``ModelRunner._prepare_decode``.

    Token-level lists hold one entry per decoded sequence (one new token
    each), possibly padded for CUDA-graph replay.
    """

    input_tokens: List[int]
    input_positions: List[int]
    # Produced by the same per-stage ``attn_backend.make_metadata`` factory
    # as the prompt metadata; ``None`` only for the empty case.
    attn_metadata: Optional[AttentionMetadataPerStage]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return metadata describing a batch with no decode requests.

        Fresh lists/sets are created on every call, so callers may extend
        the returned containers in place.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )


class BatchType(IntEnum):
    """Describes which kinds of requests make up a scheduled batch.

    The driver worker sends this value with the broadcast input metadata.
    For ``MIXED`` batches it additionally broadcasts the decode attention
    metadata in a second message.
    """

    # The batch holds only prefill (prompt) requests.
    PREFILL = 0
    # The batch holds only decode requests.
    DECODE = 1
    # The batch holds prefill requests and decode requests together.
    MIXED = 2


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the engine configuration and set up lazily-filled state.

        The model itself is not loaded here; ``load_model`` must be called
        before the runner is used.

        Args:
            model_config: Model configuration; also supplies the sliding
                window, dtype and ``max_seq_len_to_capture``.
            parallel_config: Parallelism configuration.
            scheduler_config: Scheduler configuration (e.g. whether chunked
                prefill is enabled).
            device_config: Target device configuration.
            cache_config: KV cache configuration; supplies ``block_size``.
            load_config: Weight loading configuration.
            lora_config: LoRA configuration, or ``None`` to disable LoRA.
            kv_cache_dtype: KV cache data type name. ``"fp8"`` enables
                scaling-factor loading on ROCm in ``load_model``.
            is_driver_worker: Whether this worker builds the input tensors
                and broadcasts them, instead of receiving them.
            vision_language_config: Required when requests carry
                multi-modal inputs.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        # Decode batches with a longer sequence fall back to eager mode.
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.
        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        # block_size and max_seq_len_to_capture must be set before this call.
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(self.model_config.dtype)

        # Lazy initialization
        self.model: torch.nn.Module  # Set after load_model
        # Set if the backend is flashinfer. Intentionally left unassigned:
        # _prepare_decode allocates it on first use via a hasattr() check.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        """Load the model weights and attach optional LoRA / KV-scale support.

        Records the memory consumed by the weights in
        ``self.model_memory_usage``. When LoRA is configured, the model is
        replaced by the one returned from
        ``LRUCacheWorkerLoRAManager.create_lora_manager``. For an FP8 KV cache
        on ROCm, KV cache scaling factors are loaded from
        ``model_config.quantization_param_path`` when it is set.

        Raises:
            RuntimeError: FP8 KV cache scaling factors were provided but the
                model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            # NOTE(review): self.vocab_size is not assigned in __init__; it is
            # presumably a property defined later in this class -- confirm.
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    # Format eagerly: unlike logger calls, exceptions do not
                    # substitute %-style arguments, so passing the class as a
                    # second argument would leave "%s" in the message.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        "model %s does not support loading scaling factors." %
                        self.model.__class__)
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")

    def get_max_block_per_batch(self) -> int:
        """Return the number of KV cache blocks one sequence may span.

        This is ``ceil(max_seq_len_to_capture / block_size)``. It is used as
        the width of the cached CUDA-graph block table.
        """
        block_size = self.block_size
        # Integer ceiling division.
        return (self.max_seq_len_to_capture + block_size - 1) // block_size

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PreparePromptMetadata:
        """Build the host-side inputs for the prefill part of a batch.

        Every entry must be a prompt (prefill) group holding exactly one
        sequence. For each sequence only the tokens in
        ``[context_len, seq_len)`` are fed to the model.
        ``context_len`` is the number of already-computed tokens, or the
        prefix-cached length. ``seq_len`` is capped at
        ``context_len + token_chunk_size``.

        Args:
            seq_group_metadata_list: Prefill sequence groups to schedule.

        Returns:
            The collected inputs, or ``PreparePromptMetadata.empty()`` if
            the list is empty.

        Raises:
            RuntimeError: Chunked prefill is enabled and a group has
                prefix-cached (computed) blocks.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        if len(seq_group_metadata_list) == 0:
            return PreparePromptMetadata.empty()

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (self.scheduler_config is not None
                    and self.scheduler_config.chunked_prefill_enabled
                    and not (computed_block_nums is None
                             or computed_block_nums == [])):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            context_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
            prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
            seq_lens.append(seq_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                context_len = len(computed_block_nums) * self.block_size
                # prompt_tokens was sliced from the previous context_len; this
                # re-slice assumes that value was 0 (chunked prefill, which
                # advances it, is rejected above) -- confirm.
                prompt_tokens = prompt_tokens[context_len:]
                prefix_block_tables.append(computed_block_nums)
            elif self.scheduler_config.chunked_prefill_enabled:
                if seq_group_metadata.block_tables is not None:
                    # Prefill has chunked before.
                    block_table = seq_group_metadata.block_tables[seq_id]
                    prefix_block_tables.append(block_table)
                else:
                    # The first prefill.
                    prefix_block_tables.append([])
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert context_len == 0

            # actual prompt lens
            context_lens.append(context_len)
            query_lens.append(seq_len - context_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(context_len, seq_len)))
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * (seq_len - context_len)
            # One sampling row per prompt token when prompt logprobs are
            # requested, otherwise a single row for the last token.
            lora_prompt_mapping.extend(
                [lora_id] *
                (seq_len - context_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                # NOTE(review): this emits seq_len entries while only
                # seq_len - context_len tokens were added; the two agree only
                # when context_len == 0.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert context_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(context_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_query_len = max(query_lens)
        max_seq_len = max(seq_lens)
        assert max_query_len > 0

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry k is the start offset of sequence k.
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        if self.attn_backend.get_name() == "flashinfer":
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=True,
                use_cuda_graph=False,
                seq_start_loc=seq_start_loc,
                max_seq_len=max_seq_len,
                block_tables=block_tables)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=True,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_seq_len,
                subquery_start_loc=subquery_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables,
                use_cuda_graph=False,
            )

        return PreparePromptMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping,
        )

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PrepareDecodeMetadata:
        """Build the host-side inputs for the decode part of a batch.

        Each sequence contributes exactly one input token (its last token)
        at position ``seq_len - 1``. CUDA graphs are used when all of the
        following hold:

        - eager mode is not enforced;
        - the batch is no larger than the largest captured batch size;
        - every sequence is no longer than ``max_seq_len_to_capture``.

        In that case the inputs are padded up to
        ``_get_graph_batch_size(batch_size)`` and the block tables are
        copied into the cached ``graph_block_tables`` array.

        Args:
            seq_group_metadata_list: Decode sequence groups to schedule.

        Returns:
            The collected inputs, or ``PrepareDecodeMetadata.empty()`` if
            the list is empty.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

        if len(seq_group_metadata_list) == 0:
            return PrepareDecodeMetadata.empty()

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # Attention never looks further back than the sliding window.
                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

                paged_kv_indices.extend(block_table)
                paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
                last_page_len = seq_data.get_len() % self.block_size
                if last_page_len == 0:
                    last_page_len = self.block_size
                paged_kv_last_page_len.append(last_page_len)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == input_tokens.
        batch_size = len(input_tokens)
        max_seq_len = max(seq_lens)
        use_captured_graph = (not self.model_config.enforce_eager
                              and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                              and max_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            # Pad with dummy entries whose KV writes go to _PAD_SLOT_ID.
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be
            # padded.
            assert seq_lens_tensor.shape[0] == len(input_tokens)
            assert seq_lens_tensor.shape[0] == len(input_positions)
            assert seq_lens_tensor.shape[0] == len(slot_mapping)

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # NOTE(review): the cached rows are not cleared, so entries past
            # len(block_table) may hold values written in earlier steps.
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        if self.attn_backend.get_name() == "flashinfer":
            # NOTE(review): use_cuda_graph=False is passed below even when the
            # inputs above were padded for a captured graph; presumably
            # flashinfer is only used in eager mode -- confirm.
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr = torch.tensor(paged_kv_indptr,
                                           dtype=torch.int,
                                           device=self.device)
            paged_kv_indices = torch.tensor(paged_kv_indices,
                                            dtype=torch.int,
                                            device=self.device)
            paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
                                                  dtype=torch.int,
                                                  device=self.device)
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)

            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=False,
                use_cuda_graph=False,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr,
                paged_kv_indices=paged_kv_indices,
                paged_kv_last_page_len=paged_kv_last_page_len,
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                page_size=self.block_size,
                data_type=kv_cache_dtype)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=False,
                seq_lens=None,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=None,
                max_seq_len=max_seq_len,
                subquery_start_loc=None,
                seq_start_loc=None,
                context_lens_tensor=None,
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )
        return PrepareDecodeMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            slot_mapping=slot_mapping,
        )

    def prepare_input_tensors(
        self,
602
        seq_group_metadata_list: List[SequenceGroupMetadata],
603
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
604
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
605
        if self.is_driver_worker:
606
607
608
609
610
611
612
613
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

614
            # Prepare input tensors.
615
616
617
618
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
619
620
                seq_lens,
                query_lens,
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)
636
            sampling_metadata = SamplingMetadata.prepare(
637
638
                seq_group_metadata_list, seq_lens, query_lens, self.device,
                self.pin_memory)
639

640
641
642
            if not self.scheduler_config.chunked_prefill_enabled:
                assert (len(prefill_reqs) and len(decode_reqs)) == 0

643
            num_prefills = len(seq_lens)
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors. Note that attn_metadata is currently not
            # coalesced for simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

666
667
            if self.lora_config:
                lora_mapping = LoRAMapping(
668
                    lora_index_mapping,
669
670
671
672
673
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

674
            # Broadcast the metadata.
675
676
677
678
679
680
681
682
683
684
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

685
686
687
688
689
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
690
691
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
692
                "multi_modal_input": multi_modal_input,
693
694
695
696
697
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
698
            }
699
700
701
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
702
                assert decode_attn_metadata is not None
703
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
704
            broadcast_tensor_dict(metadata_dict, src=0)
705
706
707
708
709
710
711
712

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coelescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
713
        else:
714
            metadata_dict = broadcast_tensor_dict(src=0)
715
716
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
717
718
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
719
720
721
722
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
723
            multi_modal_input = metadata_dict.pop("multi_modal_input")
724
725
726
727
728
729
730
731
732
733
734
735
736
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Create an attention metadata.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
737
738
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
739
                selected_token_indices=selected_token_indices,
740
                categorized_sample_indices=None,
741
                num_prompts=0,
742
743
            )

744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
            # if it is a mixed batch, decode attn_metadata is broadcasted
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
            kv_cache_dtype=self.kv_cache_dtype,
        )

761
        return (input_tokens, input_positions, attn_metadata,
762
763
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
764

765
766
767
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass over a scheduled batch and sample tokens.

        Args:
            seq_group_metadata_list: Scheduled sequence groups for this step;
                converted to tensors by ``prepare_input_tensors``.
            kv_caches: Per-layer KV cache tensors forwarded to the model.

        Returns:
            The sampler output on the driver worker; ``None`` on every other
            worker, since only the driver performs sampling.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            # Graph runners are looked up by the number of input tokens.
            # NOTE(review): assumes prepare_input_tensors padded the batch to
            # one of the captured batch sizes — confirm.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Execute the model once on worst-case dummy prompts to profile memory.

        Builds up to ``max_num_seqs`` dummy prompt sequence groups whose
        lengths sum to ``max_num_batched_tokens``. When LoRA is enabled,
        dummy LoRA adapters are attached, and when a vision-language config
        is set, fake image inputs are attached. The groups are run through
        ``execute_model`` with ``None`` KV caches, followed by a CUDA
        synchronize.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests that will
        # have unique loras, and therefore the max amount of memory
        # consumption. Register `max_loras` dummy LoRA adapters (rank
        # LORA_WARMUP_RANK) and assign them round-robin to the dummy
        # sequences.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # the vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            max_num_seqs = min(
                max_num_seqs,
                max_num_batched_tokens //
                self.vision_language_config.image_feature_size)
        for group_id in range(max_num_seqs):
            # Split the token budget as evenly as possible: the first
            # (total % max_num_seqs) groups receive one extra token.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
    def remove_all_loras(self):
        """Delegate to the LoRA manager's ``remove_all_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_loras()
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the active LoRA requests and token mapping to the manager.

        Args:
            lora_requests: LoRA requests to make active.
            lora_mapping: Index/prompt mapping passed through unchanged to
                the manager's ``set_active_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Hand ``lora_request`` to the LoRA manager's ``add_lora``.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: If no LoRA manager is configured.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Pass ``lora_id`` to the LoRA manager's ``remove_lora``.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: If no LoRA manager is configured.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the LoRA manager's ``list_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph is captured for every size in ``_BATCH_SIZES_TO_CAPTURE``
        that does not exceed the padded ``max_num_seqs``. The runners are
        stored in ``self.graph_runners`` keyed by batch size, and all of them
        share ``self.graph_memory_pool``.

        Args:
            kv_caches: Per-layer KV cache tensors passed to every capture.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        # Every dummy slot is the pad id. NOTE(review): presumably this keeps
        # the capture run from writing real KV cache entries — confirm.
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long).cuda()
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                decode_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_seq_len=self.max_seq_len_to_capture,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )
                attn_metadata = AttentionMetadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    prefill_metadata=None,
                    decode_metadata=decode_metadata,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    # All-zero LoRA mapping with no active requests.
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Reuse this graph's memory pool for the next (smaller) one.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
    def __del__(self) -> None:
        """Drop the CUDA graph runners, then the pynccl backend reference.

        This method tolerates partially initialized instances, such as when
        ``__init__`` raised before ``graph_runners`` was assigned. An
        exception raised from ``__del__`` is only printed to stderr and
        cannot be handled by any caller.
        """
        # Delete the CUDA graphs before deleting the pynccl communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.pynccl_backend = None
    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by ``model_config.get_vocab_size()``."""
        return self.model_config.get_vocab_size()

class CUDAGraphRunner:
    """Captures a model's forward pass into a CUDA graph and replays it.

    ``capture`` records one forward pass on fixed-size input tensors, which
    then serve as the graph's input buffers. ``forward`` copies new inputs
    into those buffers in place, replays the graph, and returns the
    captured output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Tensors the captured graph reads from / writes to, keyed by name.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured CUDA graph; only valid after ``capture`` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Record the model's forward pass on the given tensors.

        The passed tensors, together with the slot mapping, ``seq_lens_tensor``
        and ``block_tables`` from ``attn_metadata``, are kept as the graph's
        input buffers. Later ``forward`` calls overwrite them in place.

        Args:
            input_ids: Token ids for the captured batch.
            positions: Positions for the captured batch.
            kv_caches: Per-layer KV cache tensors.
            attn_metadata: Decode-only attention metadata. Its
                ``decode_metadata`` must be set.
            memory_pool: Pool handle forwarded to ``torch.cuda.graph`` so
                several graphs can share memory. NOTE(review): presumably
                ``None`` for the first capture — confirm with the caller.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self._graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                    **kwargs,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the captured buffers and replay the graph.

        Extra ``kwargs`` are accepted for signature compatibility with the
        model and are ignored.

        Returns:
            The captured ``hidden_states`` buffer. The same tensor object is
            returned on every call, so its contents are overwritten by the
            next replay.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
@contextlib.contextmanager
def _maybe_pynccl():
    """Enter ``with_pynccl_for_all_reduce()`` only when pynccl should be used.

    pynccl is selected when ``pynccl_utils`` is initialized and the custom
    all-reduce kernel is not. Otherwise the body runs inside a no-op
    context.
    """
    # Short-circuits exactly like the original check: custom_all_reduce is
    # only queried when pynccl is initialized.
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    context = (with_pynccl_for_all_reduce()
               if use_pynccl else contextlib.nullcontext())
    with context:
        yield
def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    Sizes up to 2 are returned unchanged, sizes 3-4 round up to 4, and
    anything larger rounds up to the next multiple of
    ``_BATCH_SIZE_ALIGNMENT``. This must stay in sync with
    ``_BATCH_SIZES_TO_CAPTURE``.
    """
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # Ceiling division expressed as negated floor division.
    return -(-batch_size // _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build dummy prompt data, plus a dummy image if applicable, for profiling.

    Args:
        seq_len: Number of prompt tokens to generate.
        vision_language_config: When set, the prompt starts with
            ``image_feature_size`` copies of ``image_token_id`` followed by
            zero-padding, and a zero float16 image tensor of shape
            ``image_input_shape`` is returned alongside it.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` pair.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    padding = [0] * (seq_len - num_image_tokens)
    image_data = torch.zeros(vision_language_config.image_input_shape,
                             dtype=torch.float16)
    fake_image_input = MultiModalData(type=MultiModalData.Type.IMAGE,
                                      data=image_data)
    return SequenceData(image_tokens + padding), fake_image_input