# model_runner.py (44.4 KB)
# Newer / Older
1
import time
2
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
3

4
import numpy as np
5
import torch
6
import torch.nn as nn
7

8
from vllm.attention import AttentionMetadata, get_attn_backend
9
10
11
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
12
from vllm.distributed import broadcast_tensor_dict
13
from vllm.distributed.communication_op import graph_capture
14
from vllm.logger import init_logger
15
16
17
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
18
from vllm.model_executor import SamplingMetadata
19
from vllm.model_executor.model_loader import get_model
20
from vllm.sampling_params import SamplingParams
21
22
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
23
24
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
25
26
27
28

# Module-level logger for this file, created via vLLM's `init_logger` helper.
logger = init_logger(__name__)

_PAD_SLOT_ID = -1
29
LORA_WARMUP_RANK = 8
30
31
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
32
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
33
34
35
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
36
37


38
39
40
41
class ModelInput(NamedTuple):
    """Batched model inputs produced by ``ModelRunner._prepare_model_input``.

    Token-level tensors are ordered prefill tokens first, then decode tokens.
    """
    # Flattened token ids / positions for every scheduled token.
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    # Backend-specific attention metadata; None only for an empty batch.
    attn_metadata: Optional[AttentionMetadata]
    # Per-sequence total length and number of tokens computed this step.
    seq_lens: List[int]
    query_lens: List[int]
    # LoRA routing; lora_mapping is None when LoRA is disabled.
    lora_mapping: Optional[LoRAMapping]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    # KV-cache slot per token (_PAD_SLOT_ID for tokens not to be stored).
    slot_mapping: torch.Tensor
    num_prefill_tokens: int
    num_decode_tokens: int
    num_prefills: int

    @classmethod
    def empty(cls, device):
        """Return a ModelInput that describes an empty batch on ``device``.

        The empty index tensors use ``torch.long``, the same dtype the
        non-empty path uses for ``input_tokens``, ``input_positions`` and
        ``slot_mapping``, so consumers see a consistent dtype either way.
        ``cls`` is used so subclasses get an instance of their own type.
        """
        return cls(
            input_tokens=torch.empty(0, dtype=torch.long, device=device),
            input_positions=torch.empty(0, dtype=torch.long, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_mapping=None,
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=torch.empty(0, dtype=torch.long, device=device),
            num_prefill_tokens=0,
            num_decode_tokens=0,
            num_prefills=0,
        )


70
71
72
73
74
75
76
class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Record the engine configs and precompute per-runner state.

        The model is not loaded here; ``load_model`` must be called later.

        Args:
            model_config: Model description (heads, head size, dtype,
                sliding window, max_seq_len_to_capture, ...).
            parallel_config: Used to compute per-rank head counts.
            scheduler_config: Scheduler limits and chunked-prefill flag.
            device_config: Supplies ``self.device``.
            cache_config: Supplies the KV-cache ``block_size``.
            load_config: Forwarded to ``get_model`` in ``load_model``.
            lora_config: Enables LoRA support when not None/falsy.
            kv_cache_dtype: KV-cache dtype string, passed to the attention
                backend selector; ``"fp8"`` on ROCm enables scaled KV cache
                loading in ``load_model``.
            is_driver_worker: Whether this runner builds and broadcasts the
                inputs (see ``prepare_input_tensors``).
            vision_language_config: Set for vision-language models; enables
                the ``image_input`` path.
        """
        # Keep references to every configuration object.
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config

        self.device = device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = model_config.max_seq_len_to_capture

        # CUDA-graph state; graph_memory_pool is set during graph capture.
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[int, int]] = None

        # Under CUDA graphs the input block tables must be padded out to
        # max_seq_len_to_capture. Building that padded table in Python every
        # step is expensive, so a numpy buffer of shape
        # (largest captured batch size, max blocks per sequence) is allocated
        # once and only the live rows are overwritten each iteration.
        largest_graph_batch = max(_BATCH_SIZES_TO_CAPTURE)
        self.graph_block_tables = np.zeros(
            (largest_graph_batch, self.get_max_block_per_batch()),
            dtype=np.int32)

        num_query_heads = model_config.get_num_attention_heads(parallel_config)
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.attn_backend = get_attn_backend(
            num_query_heads,
            head_size,
            num_kv_heads,
            model_config.get_sliding_window(),
            model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set if the backend is flashinfer.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
130

131
    def load_model(self) -> None:
        """Load model weights, then attach LoRA and FP8 KV-cache scaling.

        Loads the model with ``get_model`` inside a ``CudaMemoryProfiler`` and
        stores the consumed memory in ``self.model_memory_usage``. When a LoRA
        config is set, ``self.model`` is replaced by the result of
        ``lora_manager.create_lora_manager(self.model)``. When the KV cache
        dtype is ``"fp8"`` on ROCm, scaling factors are loaded from
        ``model_config.quantization_param_path`` if one is given.

        Raises:
            RuntimeError: FP8 KV-cache scaling factors were provided but the
                model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply %-style formatting to extra
                    # arguments (that is a logging idiom), so build the
                    # message explicitly to get the model class into it.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")
184

185
186
187
188
189
190
191
192
193
194
195
196
197
198
    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write the loaded model's state to ``path`` via ShardedStateLoader.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``. The loader is imported lazily,
        inside the method.
        """
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        save_options = {"pattern": pattern, "max_size": max_size}
        ShardedStateLoader.save_model(self.model, path, **save_options)

199
200
    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
201
        return (self.max_seq_len_to_capture + block_size - 1) // block_size
202

203
    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs: a
        decode-only batch that fits the captured sizes is padded up to
        ``_get_graph_batch_size(batch_size)`` with dummy entries (token 0,
        position 0, slot ``_PAD_SLOT_ID``, seq_len 1, empty block table).

        An empty ``seq_group_metadata_list`` yields ``ModelInput.empty``.
        For the flashinfer backend, a 16MB workspace buffer is allocated on
        first use and cached on ``self``.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        prefill_seq_lens: List[int] = []
        decode_seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = 0

        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    context_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    context_len + seq_group_metadata.token_chunk_size)
                if is_prompt:
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]
                    if self.attn_backend.get_name() == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
                        # TODO(woosuk): This is a temporary fix. We should
                        # provide a unified interface for different backends.
                        block_table = seq_group_metadata.block_tables[seq_id]
                    else:
                        block_table = computed_block_nums
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if self.sliding_window is not None:
                            # chunked prefill doesn't support sliding window.
                            assert (not self.scheduler_config.
                                    chunked_prefill_enabled)
                            sliding_window_blocks = (self.sliding_window //
                                                     self.block_size)
                            block_table = block_table[-sliding_window_blocks:]

                        if self.attn_backend.get_name() == "flashinfer":
                            paged_kv_indices.extend(block_table)
                            paged_kv_indptr.append(paged_kv_indptr[-1] +
                                                   len(block_table))
                            last_page_len = seq_data.get_len(
                            ) % self.block_size
                            if last_page_len == 0:
                                last_page_len = self.block_size
                            paged_kv_last_page_len.append(last_page_len)
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prefill without chunked prefill or memory profiling.
                    block_table = []
                block_tables.append(block_table)

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    seq_len = min(seq_len, self.sliding_window)
                    context_len = seq_len - 1

                seq_lens.append(seq_len)
                context_lens.append(context_len)
                query_len = seq_len - context_len
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(range(context_len, seq_len))
                lora_id = seq_group_metadata.lora_int_id

                if is_prompt:
                    assert len(seq_ids) == 1
                    num_prefills += 1
                    num_prefill_tokens += len(tokens)
                    decode_only = False
                    prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, (
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
                    decode_seq_lens.append(seq_len)

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

                lora_index_mapping += [lora_id] * (seq_len - context_len)
                lora_prompt_mapping.extend(
                    [lora_id] *
                    (seq_len -
                     context_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     else 1))

                if seq_group_metadata.multi_modal_data:
                    multi_modal_input_list.append(
                        seq_group_metadata.multi_modal_data.data)

                if _is_block_tables_empty(seq_group_metadata.block_tables):
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                    continue

                # Compute the slot mapping.
                block_table = seq_group_metadata.block_tables[seq_id]

                # Mask the [0, start_idx) tokens of the prompt with
                # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
                # sliding_window). For example, if the prompt len is 10,
                # sliding window is 8, and block size is 4, the first two
                # tokens are masked and the slot mapping will be
                # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
                start_idx = 0
                if self.sliding_window is not None:
                    if is_prompt:
                        assert context_len == 0, (
                            "Prefix caching is currently not supported with "
                            "sliding window attention")
                    # It is an optimization. When it is decoding, it is always
                    # 0. When prefill, we use it to not write slots to kv cache
                    # to save memory.
                    start_idx = max(0, query_len - self.sliding_window)

                for i in range(context_len, seq_len):
                    if i < start_idx:
                        slot_mapping.append(_PAD_SLOT_ID)
                        continue

                    block_number = block_table[i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)

        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
        max_decode_seq_len = max(decode_seq_lens, default=0)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        use_captured_graph = (
            decode_only and not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # Only the live prefix of each row is overwritten; the cached
            # numpy buffer avoids rebuilding a fully padded table each step.
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables_tensor = torch.tensor(input_block_tables,
                                               device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables_tensor = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Built once (the previous version created it twice and discarded
        # the first copy).
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        # Exclusive prefix sums: query_start_loc[i] / seq_start_loc[i] are
        # the offsets of request i, with a leading 0.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)

        if self.attn_backend.get_name() == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                  dtype=torch.int,
                                                  device=self.device)
            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                   dtype=torch.int,
                                                   device=self.device)
            paged_kv_last_page_len_tensor = torch.tensor(
                paged_kv_last_page_len, dtype=torch.int, device=self.device)
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                use_cuda_graph=False,
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables_tensor,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                # The page size must be the KV-cache block size: the block
                # tables and paged_kv_last_page_len above are expressed in
                # units of self.block_size (previously hard-coded to 16).
                page_size=self.block_size,
                seq_start_loc=seq_start_loc,
                data_type=kv_cache_dtype)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables_tensor,
                use_cuda_graph=use_captured_graph,
            )

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

        return ModelInput(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
        )
596

597
598
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
        """Build (driver) or receive (other workers) the inputs for one step.

        The driver worker builds the model input and sampling metadata from
        ``seq_group_metadata_list`` and broadcasts them with
        ``broadcast_tensor_dict(..., src=0)``. Other workers ignore their
        argument and rebuild the values from the received dict.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. On non-driver workers the sampling metadata
            only carries ``selected_token_indices``.
        """
        if not self.is_driver_worker:
            received = broadcast_tensor_dict(src=0)
            input_tokens = received.pop("input_tokens")
            input_positions = received.pop("input_positions")
            selected_token_indices = received.pop("selected_token_indices")
            lora_mapping = received.pop("lora_mapping")
            lora_requests = received.pop("lora_requests")
            multi_modal_input = received.pop("multi_modal_input")
            # Whatever is left (the num_* counters, slot_mapping and the
            # driver's attn_metadata fields) is fed to the attention backend.
            if received:
                attn_metadata = self.attn_backend.make_metadata(**received)
            else:
                attn_metadata = None
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )
            return (input_tokens, input_positions, attn_metadata,
                    sampling_metadata, lora_requests, lora_mapping,
                    multi_modal_input)

        # Driver path: build the inputs locally, then broadcast them.
        model_input = self._prepare_model_input(seq_group_metadata_list)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, model_input.seq_lens,
            model_input.query_lens, self.device, self.pin_memory)

        payload = {
            "input_tokens": model_input.input_tokens,
            "input_positions": model_input.input_positions,
            "selected_token_indices":
            sampling_metadata.selected_token_indices,
            "lora_requests": model_input.lora_requests,
            "lora_mapping": model_input.lora_mapping,
            "multi_modal_input": model_input.multi_modal_input,
            "num_prefill_tokens": model_input.num_prefill_tokens,
            "num_decode_tokens": model_input.num_decode_tokens,
            "slot_mapping": model_input.slot_mapping,
            "num_prefills": model_input.num_prefills,
        }
        if model_input.attn_metadata:
            payload.update(model_input.attn_metadata.asdict_zerocopy())
        broadcast_tensor_dict(payload, src=0)

        return (model_input.input_tokens, model_input.input_positions,
                model_input.attn_metadata, sampling_metadata,
                model_input.lora_requests, model_input.lora_mapping,
                model_input.multi_modal_input)
662

663
664
665
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one model step for a batch of sequence groups.

        Inputs are built (or, on non-driver workers, received) through
        ``prepare_input_tensors``. LoRAs are activated when ``lora_config``
        is set. A captured CUDA graph is replayed for pure-decode batches
        whose decode metadata requests it; otherwise the eager model runs.

        Returns:
            The sampler output on the driver worker, ``None`` on every
            other worker.
        """
        prepared = self.prepare_input_tensors(seq_group_metadata_list)
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input) = prepared

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # CUDA graphs are only used for the decode phase: replay one only
        # when the batch has no prefill part and decode asks for it.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        replay_graph = prefill_meta is None and decode_meta.use_cuda_graph
        # NOTE(review): assumes input_tokens was already padded to a
        # captured batch size; otherwise the graph_runners lookup fails.
        model_executable = (self.graph_runners[input_tokens.shape[0]]
                            if replay_graph else self.model)

        model_kwargs = dict(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        if self.vision_language_config:
            model_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**model_kwargs)

        # Logits are computed on every worker; only sampling is driver-only.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        if not self.is_driver_worker:
            return None
        return self.model.sample(logits=logits,
                                 sampling_metadata=sampling_metadata)

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a worst-case dummy prefill batch to profile memory usage.

        Builds up to ``max_num_seqs`` prompt-only sequence groups. Their
        lengths add up to ``max_num_batched_tokens``. They are run through
        ``execute_model`` with placeholder (``None``) KV caches.

        When LoRA is enabled, ``max_loras`` dummy adapters are registered and
        assigned to the sequences round-robin. When a vision-language config
        is present, every sequence also carries a fake image input. The
        method blocks on ``torch.cuda.synchronize()`` before returning.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # max_loras represents the maximum number of different requests that
        # will have unique LoRAs, and therefore the maximum memory
        # consumption. Register that many dummy LoRAs (created only for this
        # warmup) and spread them over the dummy sequences.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the vLLM
        # block manager.
        # To exercise the worst scenario for GPU memory consumption, the
        # number of seqs (batch_size) is chosen to maximize the number of
        # images processed.
        if self.vision_language_config:
            max_num_seqs_orig = max_num_seqs
            image_feature_size = self.vision_language_config.image_feature_size
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // image_feature_size)
            # When a single image needs more tokens than the whole batch
            # budget, the formula above yields 0 and no dummy sequence would
            # be built (profiling an empty batch). Profile one sequence.
            if max_num_seqs < 1:
                logger.warning(
                    "Computed max_num_seqs (min(%d, %d // %d)) to be less "
                    "than 1. Setting it to the minimum value of 1.",
                    max_num_seqs_orig, max_num_batched_tokens,
                    image_feature_size)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Split the token budget as evenly as possible: the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

777
    def remove_all_loras(self):
        """Remove every LoRA adapter held by the LoRA manager.

        Raises:
            RuntimeError: if LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.remove_all_loras()
781

782
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` using ``lora_mapping``.

        Delegates to the LoRA manager's ``set_active_loras``.

        Raises:
            RuntimeError: if LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add ``lora_request`` through the LoRA manager.

        Returns:
            The LoRA manager's ``add_lora`` result.

        Raises:
            RuntimeError: if LoRA is not enabled (no LoRA manager).
        """
        if self.lora_manager:
            return self.lora_manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA with integer id ``lora_id``.

        Returns:
            The LoRA manager's ``remove_lora`` result.

        Raises:
            RuntimeError: if LoRA is not enabled (no LoRA manager).
        """
        if self.lora_manager:
            return self.lora_manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the LoRA manager.

        Raises:
            RuntimeError: if LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.list_loras()

803
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Capture CUDA graphs for decode-only batches.

        The performance gain from CUDA graphs is negligible once a batch
        holds more than roughly 200 tokens. Graphs also require fixed-size
        tensors, so covering large or variable batch sizes would cost a lot
        of GPU memory. For these reasons vLLM only captures decoding
        requests; prefill and mixed (chunked prefill + decoding) batches
        are not captured.

        Because only decoding is captured, each sequence in a captured batch
        is assumed to contribute exactly one token.

        One ``CUDAGraphRunner`` is captured per size in
        ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed the padded
        ``max_num_seqs``. The runners are stored in ``self.graph_runners``
        and chained through ``self.graph_memory_pool``.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Dummy inputs sized for the largest batch; every capture below uses
        # a leading slice of the same tensors.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long).cuda()
        # Every slot points at the padding slot id.
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long).cuda()
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        max_graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        capture_sizes = [
            size for size in _BATCH_SIZES_TO_CAPTURE
            if size <= max_graph_batch_size
        ]

        with graph_capture() as capture_ctx:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(capture_sizes):
                # Decode-only metadata for this batch size.
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_prefill_seq_len=0,
                    max_decode_seq_len=self.max_seq_len_to_capture,
                    query_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )

                if self.lora_config:
                    idle_lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), idle_lora_mapping)

                runner = CUDAGraphRunner(self.model)
                runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                    stream=capture_ctx.stream,
                )
                # The next (smaller) capture reuses this graph's memory pool.
                self.graph_memory_pool = runner.graph.pool()
                self.graph_runners[batch_size] = runner

        elapsed_time = time.perf_counter() - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
889

Woosuk Kwon's avatar
Woosuk Kwon committed
890
    def __del__(self) -> None:
891
        # Delete the CUDA graphs before deleting the pynccl communicator.
Woosuk Kwon's avatar
Woosuk Kwon committed
892
893
894
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
895
896
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
Woosuk Kwon's avatar
Woosuk Kwon committed
897
        self.graph_runners.clear()
898
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
899

900
901
902
903
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        config = self.model_config
        return config.get_vocab_size()

904
905
906
907
908
909
910
911

class CUDAGraphRunner:
    """Records one forward pass of a model as a CUDA graph and replays it.

    ``capture`` runs the model once eagerly and then records a second run
    into a CUDA graph, remembering the tensors that run read and produced.
    ``forward`` copies new inputs into those remembered tensors, replays the
    graph and returns the remembered output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Tensors the captured graph reads from, keyed by name.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # Tensors the captured graph writes to, keyed by name.
        self.output_buffers: Dict[str, torch.Tensor] = {}
        # Set by `capture`; None until then.
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; asserts that ``capture`` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> None:
        """Record a forward pass of the model into a new CUDA graph.

        Args:
            input_ids, positions, kv_caches, attn_metadata, **kwargs:
                Passed straight to the model, both for the warm-up run and
                for the recorded run.
            memory_pool: Memory pool handed to ``torch.cuda.graph``.
            stream: Stream the graph is captured on.

        May only be called once per runner.
        """
        assert self._graph is None

        def run_model():
            return self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )

        # Warm-up run outside the graph, so that kernel launches done only
        # on first use (e.g., Triton autotune) are not recorded into it.
        run_model()
        torch.cuda.synchronize()

        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            hidden_states = run_model()
        torch.cuda.synchronize()

        # Remember the exact tensors the graph was recorded with.
        decode_meta = attn_metadata.decode_metadata
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": decode_meta.seq_lens_tensor,
            "block_tables": decode_meta.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the captured buffers and replay the graph.

        Extra keyword arguments are accepted but not used. The returned
        tensor is the captured output buffer itself, so a later replay
        overwrites it.
        """
        # KV caches are fixed tensors, so they are not copied.
        del kv_caches

        decode_meta = attn_metadata.decode_metadata
        fresh_inputs = (
            ("input_ids", input_ids),
            ("positions", positions),
            ("slot_mapping", attn_metadata.slot_mapping),
            ("seq_lens_tensor", decode_meta.seq_lens_tensor),
            ("block_tables", decode_meta.block_tables),
        )
        for buffer_name, new_value in fresh_inputs:
            self.input_buffers[buffer_name].copy_(new_value, non_blocking=True)

        self.graph.replay()
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

995

996
def _get_graph_batch_size(batch_size: int) -> int:
    """Return the padded batch size for an actual batch size.

    The padded sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    Sizes up to 2 are returned unchanged, 3 and 4 map to 4, and anything
    larger is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    """
    if batch_size > 4:
        # Ceiling division via negated floor division.
        num_chunks = -(-batch_size // _BATCH_SIZE_ALIGNMENT)
        return num_chunks * _BATCH_SIZE_ALIGNMENT
    return batch_size if batch_size <= 2 else 4
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run.

    Without a vision-language config the prompt is ``seq_len`` zero tokens
    and no image is returned.

    With a vision-language config the prompt starts with
    ``image_feature_size`` copies of ``image_token_id`` and is padded with
    zeros up to ``seq_len``. An all-zero float16 image tensor of shape
    ``image_input_shape`` is returned alongside it. If ``seq_len`` is smaller
    than ``image_feature_size``, no padding is added and the prompt is longer
    than ``seq_len``.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` tuple.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    padding = [0] * (seq_len - num_image_tokens)
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_tokens + padding), fake_image_input
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038


def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False