"tests/vscode:/vscode.git/clone" did not exist on "4054202359a950781f067cfc82b8a57350f28962"
model_runner.py 45 KB
Newer Older
1
import contextlib
2
import time
3
4
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9

10
11
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                            get_attn_backend)
12
13
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
14
15
16
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
                                                   pynccl_utils)
17
from vllm.logger import init_logger
18
19
20
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
21
from vllm.model_executor import SamplingMetadata
22
from vllm.model_executor.model_loader import get_model
23
from vllm.sampling_params import SamplingParams
24
25
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
26
27
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
                        make_tensor_with_pad)
28
29
30
31

logger = init_logger(__name__)

# Placeholder slot id written into slot_mapping for tokens that get no real
# KV-cache slot (dummy profiling runs, sliding-window masking, CUDA-graph
# batch padding).
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Batch sizes for which CUDA graphs are captured: 1, 2, 4, followed by every
# multiple of _BATCH_SIZE_ALIGNMENT up to 256 (8, 16, 24, ..., 256).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))


class PreparePromptMetadata(NamedTuple):
    """Host-side inputs produced by ``ModelRunner._prepare_prompt``.

    Per-token lists (``input_tokens``, ``input_positions``, ``slot_mapping``,
    ``lora_index_mapping``) are flattened across every prompt sequence of
    the batch, while ``seq_lens`` and ``query_lens`` hold one entry per
    sequence.

    Field order matters: ``ModelRunner.prepare_input_tensors`` unpacks this
    tuple positionally.
    """
    input_tokens: List[int]
    input_positions: List[int]
    attn_metadata: Optional[AttentionMetadataPerStage]
    seq_lens: List[int]
    query_lens: List[int]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return the metadata for a batch that contains no prompts."""
        return PreparePromptMetadata(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )


class PrepareDecodeMetadata(NamedTuple):
    """Host-side inputs produced by ``ModelRunner._prepare_decode``.

    ``attn_metadata`` holds the per-stage object returned by
    ``attn_backend.make_metadata`` (the same kind stored in
    ``PreparePromptMetadata``), not the combined ``AttentionMetadata``.

    Field order matters: ``ModelRunner.prepare_input_tensors`` unpacks this
    tuple positionally.
    """
    input_tokens: List[int]
    input_positions: List[int]
    attn_metadata: Optional[AttentionMetadataPerStage]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return the metadata for a batch that contains no decode steps."""
        return PrepareDecodeMetadata(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )


class BatchType(IntEnum):
    """Composition of one scheduled batch (prefill, decode, or both)."""

    PREFILL = 0  # the batch consists solely of prefill requests
    DECODE = 1  # the batch consists solely of decode requests
    MIXED = 2  # the batch mixes prefill and decode requests


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store configuration and declare state that is filled in later.

        Args:
            model_config: Model configuration. May be ``None`` (some tests
                pass it); the sliding window then becomes ``None``,
                ``max_seq_len_to_capture`` becomes 0 and the attention
                backend is selected with dtype ``None``.
            parallel_config: Parallelism configuration.
            scheduler_config: Scheduler configuration.
            device_config: Device configuration; ``None`` falls back to a
                default ``DeviceConfig()``.
            load_config: Weight-loading configuration.
            lora_config: LoRA configuration, or ``None`` to disable LoRA.
            kv_cache_dtype: KV-cache data type string (default ``"auto"``).
            is_driver_worker: Whether this worker builds inputs and
                broadcasts them to the other workers.
            vision_language_config: Optional vision-language configuration.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        # Set after load_model (only when LoRA is enabled).
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
                                       if self.model_config is not None else 0)

        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

        # Lazy initialization
        self.model: torch.nn.Module  # Set after load_model
        self.block_size: int  # Set after initial profiling.
        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        # Allocated with np.zeros in set_block_size (after initial profiling).
        self.graph_block_tables: np.ndarray

    def load_model(self) -> None:
        """Load the model weights and apply optional LoRA / KV-scale setup.

        - Loads the model via ``get_model`` and logs the GPU memory it used.
        - When ``lora_config`` is set, checks the model's LoRA attributes and
          wraps it with an ``LRUCacheWorkerLoRAManager``.
        - With ``kv_cache_dtype == "fp8"`` on ROCm, loads KV-cache scaling
          factors from ``model_config.quantization_param_path`` when given.

        Raises:
            RuntimeError: FP8 KV cache with scaling factors was requested but
                the model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply %-style formatting to extra
                    # args, so the class must be interpolated explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")

    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the CUDA-graph table.

        ``self.graph_block_tables`` becomes a zero-filled ``np.int32`` array
        of shape ``(max(_BATCH_SIZES_TO_CAPTURE),
        self.get_max_block_per_batch())``.
        """
        # Must be assigned first: get_max_block_per_batch reads it.
        self.block_size = block_size

        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        """Return ``ceil(max_seq_len_to_capture / block_size)``.

        This is the number of KV-cache blocks needed to hold
        ``self.max_seq_len_to_capture`` tokens, i.e. the width of one row of
        ``graph_block_tables``. Requires ``self.block_size`` to be set.
        """
        block_size = self.block_size
        return (self.max_seq_len_to_capture + block_size - 1) // block_size

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PreparePromptMetadata:
        """Build flattened model inputs and attention metadata for prompts.

        Every group must be a prompt group holding exactly one sequence. The
        tokens of all sequences are concatenated; for each sequence only the
        not-yet-computed tokens are emitted. Those are the tokens after
        ``seq_data.get_num_computed_tokens()`` (chunked prefill) or after
        the prefix-cached blocks listed in ``computed_block_nums``.

        Args:
            seq_group_metadata_list: Prompt sequence groups to prepare.

        Returns:
            A ``PreparePromptMetadata``; ``PreparePromptMetadata.empty()``
            when ``seq_group_metadata_list`` is empty.

        Raises:
            RuntimeError: Chunked prefill is enabled and a group has computed
                (prefix-cached) blocks, which is not supported yet.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        if len(seq_group_metadata_list) == 0:
            return PreparePromptMetadata.empty()

        # Guard against a missing scheduler_config once, so both uses below
        # treat it consistently (chunked prefill disabled).
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (chunked_prefill_enabled
                    and not (computed_block_nums is None
                             or computed_block_nums == [])):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            context_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
            prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
            seq_lens.append(seq_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                context_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[context_len:]
                prefix_block_tables.append(computed_block_nums)
            elif chunked_prefill_enabled:
                if seq_group_metadata.block_tables is not None:
                    # Prefill has chunked before.
                    block_table = seq_group_metadata.block_tables[seq_id]
                    prefix_block_tables.append(block_table)
                else:
                    # The first prefill.
                    prefix_block_tables.append([])
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert context_len == 0

            # actual prompt lens
            context_lens.append(context_len)
            query_lens.append(seq_len - context_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(context_len, seq_len))

            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * (seq_len - context_len)
            # One entry per token when prompt logprobs are requested,
            # otherwise a single entry for the sequence.
            lora_prompt_mapping.extend(
                [lora_id] *
                (seq_len - context_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert context_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(context_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_query_len = max(query_lens)
        max_seq_len = max(seq_lens)
        assert max_query_len > 0

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry i is the total length of sequences
        # [0, i); entry 0 stays 0 because only [1:] is written.
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
        )

        return PreparePromptMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping,
        )

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> PrepareDecodeMetadata:
        """Build model inputs and attention metadata for decode steps.

        Every group must be a non-prompt group with ``token_chunk_size == 1``.
        Each sequence contributes its last token at position ``seq_len - 1``.
        CUDA graphs are used only when all three hold:

        - eager mode is off;
        - the batch is no larger than ``_BATCH_SIZES_TO_CAPTURE[-1]``;
        - ``max_seq_len <= self.max_seq_len_to_capture``.

        In that case the batch is padded to
        ``_get_graph_batch_size(batch_size)`` with dummy entries whose slot
        is ``_PAD_SLOT_ID``.

        Args:
            seq_group_metadata_list: Decode sequence groups to prepare.

        Returns:
            A ``PrepareDecodeMetadata``; ``PrepareDecodeMetadata.empty()``
            when ``seq_group_metadata_list`` is empty.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        if len(seq_group_metadata_list) == 0:
            return PrepareDecodeMetadata.empty()

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # Cap the attended length at the sliding window, if any.
                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == len(input_tokens).
        batch_size = len(input_tokens)
        max_seq_len = max(seq_lens)
        use_captured_graph = (not self.model_config.enforce_eager
                              and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                              and max_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            # Pad up to the captured batch size with dummy decode entries.
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be
            # padded.
            assert seq_lens_tensor.shape[0] == len(input_tokens)
            assert seq_lens_tensor.shape[0] == len(input_positions)
            assert seq_lens_tensor.shape[0] == len(slot_mapping)

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # NOTE(review): slicing a NumPy array yields a view, so the rows
            # below are written into the cached self.graph_block_tables in
            # place; entries past each row's current length keep values from
            # earlier calls (presumably never read because attention is
            # bounded by seq_lens -- TODO confirm).
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=None,
            max_seq_len=max_seq_len,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
        return PrepareDecodeMetadata(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            lora_index_mapping=lora_index_mapping,
            lora_prompt_mapping=lora_prompt_mapping,
            lora_requests=lora_requests,
            slot_mapping=slot_mapping,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping],
               Optional[torch.Tensor]]:
        """Build (driver) or receive (other workers) the model inputs.

        Driver worker:

        - Splits the groups into prefill and decode.
        - Prepares each with ``_prepare_prompt`` / ``_prepare_decode``.
        - Concatenates the per-token inputs, prefill first.
        - Broadcasts the metadata with ``broadcast_tensor_dict``: one
          broadcast for a pure PREFILL or DECODE batch, two for a MIXED batch.

        Non-driver workers receive those broadcasts and rebuild the metadata
        locally.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. ``lora_mapping`` is ``None`` when LoRA is
            not configured; ``multi_modal_input`` is ``None`` when no group
            carries multi-modal data.
        """
        if self.is_driver_worker:
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
                seq_lens,
                query_lens,
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)
            # NOTE(review): the flattened inputs below put all prefill tokens
            # before all decode tokens, while SamplingMetadata is built from
            # seq_group_metadata_list in its original order; this presumably
            # relies on the scheduler listing prompt groups first -- confirm.
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, seq_lens, query_lens, self.device,
                self.pin_memory)

            if not self.scheduler_config.chunked_prefill_enabled:
                # Without chunked prefill a batch is all-prefill or
                # all-decode, never both.
                assert not (prefill_reqs and decode_reqs)

            num_prefills = len(seq_lens)
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors. Note that attn_metadata is currently not
            # coalesced for simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
            }
            # The first broadcast carries the prefill attention metadata when
            # there is any, otherwise the decode attention metadata.
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
                assert decode_attn_metadata is not None
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coalescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Create an attention metadata from the remaining keys, which
            # mirror what the driver put into the first broadcast.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type in (BatchType.PREFILL, BatchType.MIXED):
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )

            # if it is a mixed batch, decode attn_metadata is broadcasted
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
            kv_cache_dtype=self.kv_cache_dtype,
        )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one model step for ``seq_group_metadata_list``.

        Inputs are built by ``prepare_input_tensors``. If LoRA is configured,
        the requested adapters are activated before the forward pass. Logits
        are computed on every worker, but sampling happens only on the driver
        worker.

        Args:
            seq_group_metadata_list: The sequence groups to execute.
            kv_caches: Per-layer KV cache tensors handed to the model.

        Returns:
            The sampler output on the driver worker, ``None`` on all others.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            # The token count must equal one of the captured batch sizes
            # (the keys of self.graph_runners), otherwise this lookup raises
            # KeyError.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Execute one dummy prefill batch to measure peak memory usage.

        The batch holds up to ``max_num_seqs`` prompt sequence groups whose
        lengths add up to ``max_num_batched_tokens``. When LoRA is enabled,
        ``max_loras`` dummy adapters are registered and assigned round-robin
        to the sequences. When a vision-language config is set, the number of
        sequences is capped so that each one can carry an image prompt. If
        not even one image prompt fits in the token budget, a single
        sequence is used and a warning is logged.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora requests, each registered with the
        # lora manager as a dummy adapter of rank LORA_WARMUP_RANK.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            image_feature_size = (
                self.vision_language_config.image_feature_size)
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // image_feature_size)
            if max_num_seqs < 1:
                # Not even one image prompt fits in the token budget. Without
                # this clamp the per-sequence length computation below would
                # divide by zero.
                logger.warning(
                    "max_num_batched_tokens (%d) is smaller than the image "
                    "feature size (%d); profiling with a single sequence.",
                    max_num_batched_tokens, image_feature_size)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

    def remove_all_loras(self):
        """Remove every adapter from the LoRA manager.

        Raises:
            RuntimeError: If LoRA is not enabled (``lora_manager`` is falsy).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_loras()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` with ``lora_mapping`` on the manager.

        Raises:
            RuntimeError: If LoRA is not enabled (``lora_manager`` is falsy).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Hand ``lora_request`` to the LoRA manager and return its result.

        Raises:
            RuntimeError: If LoRA is not enabled (``lora_manager`` is falsy).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to drop adapter ``lora_id``; return its result.

        Raises:
            RuntimeError: If LoRA is not enabled (``lora_manager`` is falsy).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the set of LoRA ids reported by the LoRA manager.

        Raises:
            RuntimeError: If LoRA is not enabled (``lora_manager`` is falsy).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

847
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph is captured for each size in ``_BATCH_SIZES_TO_CAPTURE``
        that does not exceed the padded ``max_num_seqs``. The runners are
        stored in ``self.graph_runners`` keyed by batch size and share
        ``self.graph_memory_pool``.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # Allocate directly on the current CUDA device instead of building
        # the tensors on the CPU and copying them over.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   dtype=torch.long,
                                   device="cuda")
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long,
                                      device="cuda")
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long,
                                  device="cuda")
        seq_lens = torch.ones(max_batch_size,
                              dtype=torch.int32,
                              device="cuda")
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                decode_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_seq_len=self.max_seq_len_to_capture,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )
                attn_metadata = AttentionMetadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    prefill_metadata=None,
                    decode_metadata=decode_metadata,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Later (smaller) captures reuse this graph's memory pool.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    def __del__(self) -> None:
949
        # Delete the CUDA graphs before deleting the pynccl communicator.
Woosuk Kwon's avatar
Woosuk Kwon committed
950
951
952
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
953
954
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
Woosuk Kwon's avatar
Woosuk Kwon committed
955
        self.graph_runners.clear()
956
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
957

958
959
960
961
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model config."""
        config = self.model_config
        return config.get_vocab_size()

962
963
964
965
966
967
968
969

class CUDAGraphRunner:
    """Record a model's forward pass as a CUDA graph and replay it.

    ``capture`` runs the model once eagerly, then records a second run into
    a ``torch.cuda.CUDAGraph``. It keeps the tensors it was given as static
    input buffers and the produced hidden states as the static output
    buffer. ``forward`` copies new inputs into those buffers and replays the
    graph.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Static buffers recorded by ``capture``. NOTE(review): the
        # "kv_caches" entry is a list of tensors, not a single tensor.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; asserts that ``capture`` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Capture the model's forward pass on the given inputs.

        The exact tensors passed in become the graph's static input buffers,
        so ``forward`` later copies fresh data into them.
        ``attn_metadata.decode_metadata`` must be set, because its
        ``seq_lens_tensor`` and ``block_tables`` are recorded as buffers.

        Args:
            input_ids: Token ids for the captured batch.
            positions: Positions for the captured batch.
            kv_caches: Per-layer KV cache tensors.
            attn_metadata: Decode-only attention metadata.
            memory_pool: Passed as ``pool`` to ``torch.cuda.graph``.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self._graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                    **kwargs,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs.

        ``kv_caches`` and any extra ``kwargs`` are ignored. The returned
        tensor is the static output buffer recorded at capture time, so every
        replay writes into that same tensor.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


@contextlib.contextmanager
def _maybe_pynccl():
    """Enter ``with_pynccl_for_all_reduce()`` only when it is needed.

    That context is used only when pynccl is initialized and the custom
    all-reduce kernel is not. In every other case this context manager does
    nothing.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    if use_pynccl:
        with with_pynccl_for_all_reduce():
            yield
    else:
        yield


def _get_graph_batch_size(batch_size: int) -> int:
1068
1069
1070
1071
1072
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
    """
1073
1074
1075
1076
1077
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
1078
1079
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build a dummy prompt, and for VLMs a dummy image, for the profile run.

    Without a vision-language config the prompt is ``seq_len`` zero tokens
    and no image is returned. With one, the prompt starts with
    ``image_feature_size`` copies of ``image_token_id``, followed by zero
    tokens up to ``seq_len``. It is returned together with an all-zero
    float16 image of ``image_input_shape``.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` pair.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    # When seq_len < num_image_tokens this padding is empty, so the prompt
    # ends up longer than seq_len.
    padding = [0] * (seq_len - num_image_tokens)
    fake_image = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_tokens + padding), fake_image