model_runner.py 47.8 KB
Newer Older
1
import time
2
from enum import IntEnum
3
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

9
10
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                            get_attn_backend)
11
12
13
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
14
from vllm.distributed import broadcast_tensor_dict
15
from vllm.distributed.communication_op import graph_capture, graph_mode
16
from vllm.logger import init_logger
17
18
19
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
20
from vllm.model_executor import SamplingMetadata
21
from vllm.model_executor.model_loader import get_model
22
from vllm.sampling_params import SamplingParams
23
24
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
25
26
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
27
28
29
30

# Module-level logger for this file, obtained from vLLM's init_logger helper.
logger = init_logger(__name__)

# Slot id written into slot_mapping for tokens that have no real KV-cache
# slot: profiling/embedding runs, prompt tokens outside the sliding window,
# and the padding entries added for CUDA-graph decode batches.
_PAD_SLOT_ID = -1
# NOTE(review): presumably the rank of dummy LoRA adapters used during
# warm-up/profiling; not referenced in the visible part of this file.
LORA_WARMUP_RANK = 8
# Captured CUDA-graph batch sizes above 4 are multiples of this value.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for decode batch sizes (== number of decode tokens)
# 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
38
39


40
41
42
43
class PreparePromptMetadata(NamedTuple):
    """Flattened inputs for the prefill (prompt) part of a batch.

    Produced by ``ModelRunner._prepare_prompt``. The per-token lists are later
    concatenated with the decode inputs and turned into tensors. Field order
    matters: callers unpack this tuple positionally.
    """
    # Token ids scheduled in this step, concatenated over all prompts.
    input_tokens: List[int]
    # Position of each entry of input_tokens within its sequence.
    input_positions: List[int]
    # Backend-specific prefill attention metadata; None for an empty batch.
    attn_metadata: Optional[AttentionMetadataPerStage]
    # Per sequence: length after this step (context + scheduled tokens).
    seq_lens: List[int]
    # Per sequence: number of tokens scheduled (computed) in this step.
    query_lens: List[int]
    # LoRA id for every input token; ids > 0 refer to an active adapter.
    lora_index_mapping: List[int]
    # LoRA id for every position that will be sampled.
    lora_prompt_mapping: List[int]
    # Distinct LoRA requests referenced by this batch.
    lora_requests: Set[LoRARequest]
    # Multi-modal inputs of all prompts concatenated along dim 0, or None.
    multi_modal_input: Optional[torch.Tensor]
    # KV-cache slot for every input token (_PAD_SLOT_ID where none applies).
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return metadata describing a batch with no prompt sequences."""
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )


class PrepareDecodeMetadata(NamedTuple):
    """Flattened inputs for the decode part of a batch.

    Produced by ``ModelRunner._prepare_decode`` (one generated token per
    sequence). Field order matters: callers unpack this tuple positionally.
    """
    # Last token id of every decoding sequence (plus CUDA-graph padding).
    input_tokens: List[int]
    # Position of each entry of input_tokens.
    input_positions: List[int]
    # Backend-specific decode attention metadata; None for an empty batch.
    attn_metadata: Optional[AttentionMetadata]
    # LoRA id for every input token; ids > 0 refer to an active adapter.
    lora_index_mapping: List[int]
    # LoRA id for every sampled position.
    lora_prompt_mapping: List[int]
    # Distinct LoRA requests referenced by this batch.
    lora_requests: Set[LoRARequest]
    # KV-cache slot for every input token (_PAD_SLOT_ID for padding).
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return metadata describing a batch with no decode sequences."""
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )


# How batches are constructed.
class BatchType(IntEnum):
    """Which kinds of sequences a scheduled batch contains.

    PREFILL -- only prompt (prefill) sequences.
    DECODE  -- only decoding sequences.
    MIXED   -- a mixture of prefill and decode sequences.
    """

    PREFILL = 0
    DECODE = 1
    MIXED = 2


100
101
102
103
104
105
106
class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store configuration and allocate state needed before model load.

        The model itself (and the LoRA manager, when LoRA is enabled) is
        created later by ``load_model``.

        Args:
            model_config, parallel_config, scheduler_config, device_config,
            cache_config, load_config, lora_config: configuration objects
                kept on the instance for later use.
            kv_cache_dtype: KV-cache dtype name; ``"fp8"`` enables the ROCm
                scaling-factor path in ``load_model``.
            is_driver_worker: if True this runner builds model inputs and
                broadcasts them; otherwise it receives them via broadcast.
            vision_language_config: required when requests carry
                multi-modal data.
        """
        # Configuration handles.
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config

        self.device = device_config.device
        self.pin_memory = is_pin_memory_available()

        # Values derived from the configs.
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = model_config.max_seq_len_to_capture

        # CUDA-graph state.
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.
        # With CUDA graphs, decode block tables must be padded to
        # max_seq_len_to_capture. Building such tables in Python every step is
        # expensive, so a numpy buffer is kept and only the live entries are
        # copied in each iteration. Shape:
        # (max batch size to capture, max context len to capture / block size).
        widest_graph_batch = max(_BATCH_SIZES_TO_CAPTURE)
        self.graph_block_tables = np.zeros(
            (widest_graph_batch, self.get_max_block_per_batch()),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(model_config.dtype)

        # Attributes populated lazily.
        # Assigned by load_model().
        self.model: torch.nn.Module
        # Left unassigned on purpose: only the flashinfer backend allocates it
        # (on first use), and that code checks for it with hasattr().
        self.flashinfer_workspace_buffer: torch.Tensor
        # Replaced by load_model() when LoRA is enabled.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        """Load the model onto the device and apply optional extras.

        Records the memory taken by the weights in ``self.model_memory_usage``.
        When LoRA is enabled, wraps the model with an
        ``LRUCacheWorkerLoRAManager``. On ROCm with an FP8 KV cache, loads
        KV-cache scaling factors from ``quantization_param_path`` if one is
        configured.

        Raises:
            AssertionError: LoRA is enabled but the model lacks the attributes
                the LoRA manager needs.
            RuntimeError: FP8 scaling factors were provided on ROCm but the
                model has no ``load_kv_cache_scales`` method.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            # self.vocab_size is not assigned in __init__; it is expected to
            # be provided elsewhere on this class.
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply %-style arguments, so format the
                    # message explicitly; passing the class as a second arg
                    # would leave "%s" unformatted in the error.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")

    def get_max_block_per_batch(self) -> int:
        """Number of KV-cache blocks spanned by ``max_seq_len_to_capture``.

        This is the block-table width of the CUDA-graph block-table buffer. A
        partially filled trailing block counts as a whole block.
        """
        tokens = self.max_seq_len_to_capture
        per_block = self.block_size
        # Integer ceiling division.
        return (tokens + per_block - 1) // per_block

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PreparePromptMetadata:
        """Build flattened model inputs for the prompt (prefill) sequences.

        Every group must be a prompt containing exactly one sequence. For each
        sequence, the tokens in ``[context_len, seq_len)`` are scheduled.
        ``context_len`` is the number of already-computed tokens, or the
        prefix-cached length when prefix caching applies. ``seq_len`` is
        capped by the group's ``token_chunk_size``.

        Args:
            seq_group_metadata_list: prompt sequence groups to prepare.

        Returns:
            A ``PreparePromptMetadata``; ``PreparePromptMetadata.empty()`` if
            the list is empty.

        Raises:
            RuntimeError: chunked prefill is enabled and a group has
                prefix-cached (computed) blocks.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        if len(seq_group_metadata_list) == 0:
            return PreparePromptMetadata.empty()

        # Evaluated once with a None guard so both uses below agree. The
        # previous code guarded only the first use.
        chunked_prefill_enabled = bool(
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if chunked_prefill_enabled and computed_block_nums:
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            context_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
            prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
            seq_lens.append(seq_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums and self.sliding_window is None:
                # Prefix caching (unsupported with sliding window): skip the
                # tokens already held in the computed blocks.
                context_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[context_len:]
                prefix_block_tables.append(computed_block_nums)
            elif chunked_prefill_enabled:
                if seq_group_metadata.block_tables is not None:
                    # Prefill has chunked before.
                    block_table = seq_group_metadata.block_tables[seq_id]
                    prefix_block_tables.append(block_table)
                else:
                    # The first prefill.
                    prefix_block_tables.append([])
            else:
                prefix_block_tables.append([])
                # Without chunked prefill or prefix caching, prefill always
                # starts at position 0.
                assert context_len == 0

            # actual prompt lens
            context_lens.append(context_len)
            query_lens.append(seq_len - context_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(context_len, seq_len)))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * (seq_len - context_len)
            # One sampled position per prompt, or one per scheduled token when
            # prompt logprobs are requested.
            lora_prompt_mapping.extend([lora_id] * (
                seq_len - context_len if seq_group_metadata.sampling_params
                and seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if _is_block_tables_empty(seq_group_metadata.block_tables):
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                # In embeddings, the block tables are {seq_id: None}.
                # Emit exactly one dummy slot per token appended above so that
                # slot_mapping stays aligned with input_tokens (and with the
                # decode slots concatenated after it) even if context_len > 0.
                slot_mapping.extend([_PAD_SLOT_ID] * (seq_len - context_len))
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]

            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert context_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(context_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_query_len = max(query_lens)
        max_seq_len = max(seq_lens)
        assert max_query_len > 0

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: element k is where sequence k starts.
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        if self.attn_backend.get_name() == "flashinfer":
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=True,
                use_cuda_graph=False,
                seq_start_loc=seq_start_loc,
                max_seq_len=max_seq_len,
                block_tables=block_tables)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=True,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_seq_len=max_seq_len,
                subquery_start_loc=subquery_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables,
                use_cuda_graph=False,
            )

        return PreparePromptMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping,
        )

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PrepareDecodeMetadata:
        """Build flattened model inputs for decode sequences.

        Each sequence contributes exactly one token: its last token, at
        position ``len - 1``. CUDA graphs are used when eager mode is off, the
        batch is no larger than the largest captured batch, and the longest
        sequence fits in ``max_seq_len_to_capture``. In that case the inputs
        are padded up to the captured batch size and the block tables are
        taken from the preallocated ``self.graph_block_tables`` buffer.

        Args:
            seq_group_metadata_list: decode sequence groups to prepare.

        Returns:
            A ``PrepareDecodeMetadata``; ``PrepareDecodeMetadata.empty()`` if
            the list is empty.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

        if len(seq_group_metadata_list) == 0:
            return PrepareDecodeMetadata.empty()

        # The paged-KV lists above are consumed only by the flashinfer
        # backend. Skip building them (a copy of every block table per decode
        # step) for other backends.
        use_flashinfer = self.attn_backend.get_name() == "flashinfer"

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # With a sliding window, attention spans at most the window.
                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

                if use_flashinfer:
                    paged_kv_indices.extend(block_table)
                    paged_kv_indptr.append(paged_kv_indptr[-1] +
                                           len(block_table))
                    last_page_len = seq_data.get_len() % self.block_size
                    if last_page_len == 0:
                        last_page_len = self.block_size
                    paged_kv_last_page_len.append(last_page_len)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == input_tokens.
        batch_size = len(input_tokens)
        max_seq_len = max(seq_lens)
        use_captured_graph = (not self.model_config.enforce_eager
                              and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                              and max_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            # Pad up to the nearest captured batch size with dummy entries
            # whose KV writes are disabled via _PAD_SLOT_ID.
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be
            # padded.
            assert seq_lens_tensor.shape[0] == len(input_tokens)
            assert seq_lens_tensor.shape[0] == len(input_positions)
            assert seq_lens_tensor.shape[0] == len(slot_mapping)

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # NOTE: this is a view of the persistent buffer; only the live
            # prefix of each row is overwritten.
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        if use_flashinfer:
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr = torch.tensor(paged_kv_indptr,
                                           dtype=torch.int,
                                           device=self.device)
            paged_kv_indices = torch.tensor(paged_kv_indices,
                                            dtype=torch.int,
                                            device=self.device)
            paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
                                                  dtype=torch.int,
                                                  device=self.device)
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)

            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=False,
                use_cuda_graph=False,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr,
                paged_kv_indices=paged_kv_indices,
                paged_kv_last_page_len=paged_kv_last_page_len,
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                page_size=self.block_size,
                data_type=kv_cache_dtype)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=False,
                seq_lens=None,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=None,
                max_seq_len=max_seq_len,
                subquery_start_loc=None,
                seq_start_loc=None,
                context_lens_tensor=None,
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )

        return PrepareDecodeMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            slot_mapping=slot_mapping,
        )

    def prepare_input_tensors(
        self,
600
        seq_group_metadata_list: List[SequenceGroupMetadata],
601
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
602
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
603
        if self.is_driver_worker:
604
605
606
607
608
609
610
611
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

612
            # Prepare input tensors.
613
614
615
616
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
617
618
                seq_lens,
                query_lens,
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)
634
            sampling_metadata = SamplingMetadata.prepare(
635
636
                seq_group_metadata_list, seq_lens, query_lens, self.device,
                self.pin_memory)
637

638
639
640
            if not self.scheduler_config.chunked_prefill_enabled:
                assert (len(prefill_reqs) and len(decode_reqs)) == 0

641
            num_prefills = len(seq_lens)
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors. Note that attn_metadata is currently not
            # coalesced for simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

664
665
            if self.lora_config:
                lora_mapping = LoRAMapping(
666
                    lora_index_mapping,
667
668
669
670
671
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

672
            # Broadcast the metadata.
673
674
675
676
677
678
679
680
681
682
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

683
684
685
686
687
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
688
689
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
690
                "multi_modal_input": multi_modal_input,
691
692
693
694
695
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
696
            }
697
698
699
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
700
                assert decode_attn_metadata is not None
701
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
702
            broadcast_tensor_dict(metadata_dict, src=0)
703
704
705
706
707
708
709
710

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coelescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
711
        else:
712
            metadata_dict = broadcast_tensor_dict(src=0)
713
714
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
715
716
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
717
718
719
720
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
721
            multi_modal_input = metadata_dict.pop("multi_modal_input")
722
723
724
725
726
727
728
729
730
731
732
733
734
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Create an attention metadata.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
735
736
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
737
                selected_token_indices=selected_token_indices,
738
                categorized_sample_indices=None,
739
                num_prompts=0,
740
741
            )

742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
            # if it is a mixed batch, decode attn_metadata is broadcasted
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
            kv_cache_dtype=self.kv_cache_dtype,
        )

759
        return (input_tokens, input_positions, attn_metadata,
760
761
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
762

763
764
765
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward step and, on the driver worker, sample from it.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.
            kv_caches: Per-layer KV cache tensors handed to the model
                (``None`` placeholders during the profiling run).

        Returns:
            The sampler output on the driver worker. Every other worker
            still runs the forward pass and ``compute_logits`` but returns
            ``None``.
        """
        prepared = self.prepare_input_tensors(seq_group_metadata_list)
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input) = prepared

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # CUDA graphs are only used for decode-only batches; prefill or mixed
        # batches (and decode batches with graphs disabled) run eagerly.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        replay_graph = prefill_meta is None and decode_meta.use_cuda_graph
        if replay_graph:
            # Captured runners are keyed by the number of input tokens.
            model_executable = self.graph_runners[input_tokens.shape[0]]
        else:
            model_executable = self.model

        model_kwargs = dict(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        if self.vision_language_config:
            model_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**model_kwargs)

        # Every worker computes the logits; only the driver samples.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
        if not self.is_driver_worker:
            return None

        return self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Profile memory usage with a worst-case batch of dummy prompts.

        Builds prompt-only sequence groups whose lengths add up to
        ``scheduler_config.max_num_batched_tokens``. The tokens are spread
        over at most ``scheduler_config.max_num_seqs`` groups, or fewer for
        vision-language models. When LoRA is enabled it also registers dummy
        LoRA adapters. It then runs :meth:`execute_model` with ``None`` KV
        cache placeholders and synchronizes CUDA before returning.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # ``max_loras`` is the maximum number of distinct LoRAs that requests
        # in one batch can use, and therefore the worst case for memory
        # consumption. Register that many dummy adapters and assign them
        # round-robin to the dummy sequences.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Vision encoding may need additional GPU memory, which has to be
        # accounted for when calculating the GPU blocks for the vLLM block
        # manager. To exercise the worst case, the number of seqs (batch
        # size) is chosen to maximize the number of images processed.
        if self.vision_language_config:
            image_feature_size = self.vision_language_config.image_feature_size
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // image_feature_size)
            if max_num_seqs < 1:
                # Without this clamp the dummy batch would contain no
                # sequences at all, so nothing would be profiled.
                logger.warning(
                    "max_num_batched_tokens (%d) is smaller than the image "
                    "feature size (%d); profiling with a single sequence.",
                    max_num_batched_tokens, image_feature_size)
                max_num_seqs = 1

        # Profile memory usage with max_num_seqs sequences whose lengths sum
        # to max_num_batched_tokens; the remainder goes to the first groups.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs; KV caches are passed as
        # per-layer ``None`` placeholders.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

    def remove_all_loras(self):
        """Ask the LoRA manager to drop every loaded adapter.

        Raises:
            RuntimeError: If LoRA support is not enabled.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.remove_all_loras()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the active LoRA requests and their mapping to the manager.

        Raises:
            RuntimeError: If LoRA support is not enabled.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load ``lora_request`` through the LoRA manager.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: If LoRA support is not enabled.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Unload the adapter with id ``lora_id`` through the LoRA manager.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: If LoRA support is not enabled.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the adapter ids reported by the LoRA manager.

        Raises:
            RuntimeError: If LoRA support is not enabled.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Record the decode forward pass as CUDA graphs, one per batch size.

        A CUDA graph brings a negligible speedup once a batch holds more
        than about 200 tokens. Graphs also need fixed-size tensors, so
        supporting large or variable batch sizes would cost a lot of GPU
        memory. Only decode-only batches are therefore captured; prefill
        batches and mixed (chunked prefill + decode) batches are not.

        Because the graphs serve decoding only, every sequence in a
        captured batch is assumed to contribute exactly one token.

        Args:
            kv_caches: KV cache tensors handed to every captured graph.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Static input buffers sized for the largest capture batch. Every
        # graph below is captured on a prefix slice of these same tensors.
        largest_batch = max(_BATCH_SIZES_TO_CAPTURE)
        static_tokens = torch.zeros(largest_batch, dtype=torch.long).cuda()
        static_positions = torch.zeros(largest_batch, dtype=torch.long).cuda()
        static_slots = torch.empty(largest_batch, dtype=torch.long).cuda()
        static_slots.fill_(_PAD_SLOT_ID)
        static_seq_lens = torch.ones(largest_batch, dtype=torch.int32).cuda()
        static_block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Capture only the sizes that do not exceed max_num_seqs padded up to
        # a graph batch size.
        max_graph_batch = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        capture_sizes = [
            size for size in _BATCH_SIZES_TO_CAPTURE if size <= max_graph_batch
        ]

        with graph_capture():
            # NOTE: Capturing the largest batch size first may help reduce
            # the memory usage of CUDA graph.
            for batch_size in reversed(capture_sizes):
                # Decode-only attention metadata for a dummy batch.
                decode_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    seq_lens=None,
                    seq_lens_tensor=static_seq_lens[:batch_size],
                    max_query_len=None,
                    max_seq_len=self.max_seq_len_to_capture,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=static_block_tables[:batch_size],
                    use_cuda_graph=True,
                )
                attn_metadata = AttentionMetadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=static_slots[:batch_size],
                    prefill_metadata=None,
                    decode_metadata=decode_metadata,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    # Capture with an empty LoRA set and an all-zero mapping.
                    self.set_active_loras(
                        set(),
                        LoRAMapping([0] * batch_size, [0] * batch_size),
                    )

                runner = CUDAGraphRunner(self.model)
                runner.capture(
                    static_tokens[:batch_size],
                    static_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Subsequent captures reuse this graph's memory pool.
                self.graph_memory_pool = runner.graph.pool()
                self.graph_runners[batch_size] = runner

        elapsed_time = time.perf_counter() - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    def __del__(self) -> None:
        """Drop the captured CUDA graphs and the pynccl backend reference."""
        # Delete the CUDA graphs before deleting the pynccl communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        #
        # ``__del__`` also runs on partially constructed objects (e.g. when
        # ``__init__`` raised early), where ``graph_runners`` may not exist.
        # Look it up defensively instead of raising AttributeError during
        # finalization.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.pynccl_backend = None

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model config."""
        model_config = self.model_config
        return model_config.get_vocab_size()


class CUDAGraphRunner:
    """Capture one decode forward pass of ``model`` in a CUDA graph and
    replay it on new inputs.

    The tensors given to :meth:`capture` become static input buffers.
    :meth:`forward` copies fresh inputs into them, replays the graph, and
    returns the static output tensor that was produced during capture.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; only valid after :meth:`capture` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Warm the model up once, then record its forward pass in a graph.

        Args:
            input_ids, positions, kv_caches, attn_metadata: Model inputs;
                these exact tensors become the graph's static buffers.
            memory_pool: Graph memory pool handle passed to
                ``torch.cuda.graph`` (may be shared between runners).
            **kwargs: Extra model keyword arguments used during capture.
        """
        assert self._graph is None
        model_args = (input_ids, positions, kv_caches, attn_metadata)

        # A plain run first, so one-off kernel launches for initial
        # benchmarking (e.g., Triton autotune) stay out of the graph.
        with graph_mode():
            self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Record the graph. Both context managers are entered left to right,
        # exactly like the equivalent nested ``with`` blocks.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool), graph_mode():
            hidden_states = self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Keep references to the static tensors the graph reads and writes.
        decode_metadata = attn_metadata.decode_metadata
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": decode_metadata.seq_lens_tensor,
            "block_tables": decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the static buffers and replay the graph.

        ``kv_caches`` and any extra ``**kwargs`` are accepted for interface
        compatibility with the eager model but are not copied.

        Returns:
            The static ``hidden_states`` output tensor of the graph.
        """
        # KV caches are fixed tensors, so they need no copy.
        del kv_caches

        decode_metadata = attn_metadata.decode_metadata
        new_values = (
            ("input_ids", input_ids),
            ("positions", positions),
            ("slot_mapping", attn_metadata.slot_mapping),
            ("seq_lens_tensor", decode_metadata.seq_lens_tensor),
            ("block_tables", decode_metadata.block_tables),
        )
        for buffer_name, source in new_values:
            self.input_buffers[buffer_name].copy_(source, non_blocking=True)

        self.graph.replay()
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def _get_graph_batch_size(batch_size: int) -> int:
    """Round ``batch_size`` up to the padded CUDA-graph batch size.

    The padded sizes are 1, 2, 4 and then multiples of
    ``_BATCH_SIZE_ALIGNMENT``; sizes of 2 or less are returned unchanged.
    """
    if batch_size > 4:
        # Ceiling division via negated floor division.
        num_chunks = -(-batch_size // _BATCH_SIZE_ALIGNMENT)
        return num_chunks * _BATCH_SIZE_ALIGNMENT
    return batch_size if batch_size <= 2 else 4


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build a dummy prompt (plus a dummy image for VLMs) for profiling.

    Without a vision-language config the prompt is ``seq_len`` zero tokens
    and no image is returned. With one, the prompt starts with
    ``image_feature_size`` image tokens followed by zeros up to ``seq_len``.
    If ``seq_len`` is smaller than the feature size, the prompt holds only
    the image tokens. The dummy image is an all-zero float16 tensor of
    ``image_input_shape``.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` pair.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    padding = [0] * (seq_len - num_image_tokens)
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_tokens + padding), fake_image_input


def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """Return True when ``block_tables`` holds no block table.

    That is the case for ``None`` and for a dict whose values are all
    ``None`` (including an empty dict). Any other value yields False.
    """
    if isinstance(block_tables, dict):
        return all(table is None for table in block_tables.values())
    return block_tables is None