import contextlib
import time
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
    with_pynccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)

# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)

# Placeholder slot index written into a slot mapping for tokens that have no
# real KV-cache slot (dummy profiling inputs, sliding-window-masked prompt
# tokens, and CUDA-graph padding entries).
_PAD_SLOT_ID = -1
# NOTE(review): presumably the rank of the dummy LoRA adapters created for
# warmup/profiling; confirm against profile_run.
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# CUDA graphs are captured for decode batch sizes 1, 2, 4 and then every
# multiple of _BATCH_SIZE_ALIGNMENT up to 32 * _BATCH_SIZE_ALIGNMENT, i.e.
# 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, 32 * _BATCH_SIZE_ALIGNMENT + 1,
          _BATCH_SIZE_ALIGNMENT))


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the configuration and initialise lazily-populated state.

        The model, the LoRA manager, the KV-cache block size and the cached
        CUDA-graph block table are not created here; they start as ``None``
        and are filled in by ``load_model``, ``set_block_size`` and graph
        capture.

        Args:
            model_config: Model configuration. May be None (sampler tests);
                model-derived settings then fall back to None / 0.
            parallel_config: Parallelism configuration.
            scheduler_config: Scheduler configuration.
            device_config: Device configuration; ``DeviceConfig()`` is used
                when None is passed.
            lora_config: LoRA configuration, or None when LoRA is disabled.
            kv_cache_dtype: KV-cache data type string.
            is_driver_worker: Whether this runner builds the inputs and
                broadcasts them to the other workers.
            vision_language_config: Vision-language configuration, if any.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        has_model_config = model_config is not None
        if has_model_config:
            self.sliding_window = model_config.get_sliding_window()
        else:
            self.sliding_window = None

        if device_config is None:
            device_config = DeviceConfig()
        self.device_config = device_config
        self.device = device_config.device

        # Filled in later: the model and LoRA manager by load_model(), the
        # block size by set_block_size() after initial profiling.
        self.model = None
        self.block_size = None
        self.lora_manager = None

        # Captured CUDA graph runners, keyed by (padded) batch size.
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        if has_model_config:
            self.max_context_len_to_capture = (
                self.model_config.max_context_len_to_capture)
        else:
            self.max_context_len_to_capture = 0
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        attn_dtype = self.model_config.dtype if has_model_config else None
        self.attn_backend = get_attn_backend(attn_dtype)

    def load_model(self) -> None:
        """Load the model weights and, when LoRA is enabled, wrap the model.

        The memory consumed while loading is stored in
        ``self.model_memory_usage`` and logged in units of 2**30 bytes. With
        a ``lora_config``, the model must expose ``supported_lora_modules``
        (non-empty), ``embedding_modules`` and ``embedding_padding_modules``.
        An ``LRUCacheWorkerLoRAManager`` is then built, and ``self.model`` is
        replaced by the model returned from its ``create_lora_manager``.
        """
        with CudaMemoryProfiler() as profiler:
            self.model = get_model(
                self.model_config,
                self.device_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)

        self.model_memory_usage = profiler.consumed_memory
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if not self.lora_config:
            return

        model = self.model
        assert hasattr(model, "supported_lora_modules"
                       ) and model.supported_lora_modules, (
                           "Model does not support LoRA")
        assert hasattr(
            model,
            "embedding_modules"), "Model does not have embedding_modules"
        assert hasattr(model, "embedding_padding_modules"
                       ), "Model does not have embedding_padding_modules"
        self.lora_manager = LRUCacheWorkerLoRAManager(
            self.scheduler_config.max_num_seqs,
            self.scheduler_config.max_num_batched_tokens,
            self.vocab_size,
            self.lora_config,
            self.device,
            model.embedding_modules,
            model.embedding_padding_modules,
        )
        self.model = self.lora_manager.create_lora_manager(model)

    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the graph block table.

        The cached table is a zero-filled ``np.int32`` array with one row per
        slot of the largest captured batch size and one column per block
        needed for ``max_context_len_to_capture`` tokens.
        """
        self.block_size = block_size
        num_rows = max(_BATCH_SIZES_TO_CAPTURE)
        num_cols = self.get_max_block_per_batch()
        self.graph_block_tables = np.zeros((num_rows, num_cols),
                                           dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_context_len_to_capture + block_size - 1) // block_size
133

134
135
136
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest],
               Optional[torch.Tensor]]:
        """Build model inputs and attention metadata for a prefill batch.

        Every sequence group must be a prompt holding exactly one sequence.
        When a group has non-empty ``computed_block_nums`` and no sliding
        window is configured, the tokens covered by those blocks are dropped
        from the flattened inputs. Those blocks are passed instead as that
        sequence's prefix block table.

        Args:
            seq_group_metadata_list: Non-empty list of prompt sequence groups.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, prompt_lens,
            subquery_lens, lora_index_mapping, lora_prompt_mapping,
            lora_requests, multi_modal_input)``.
            - ``input_tokens`` and ``input_positions`` are flattened 1-D
              long tensors on ``self.device``.
            - ``subquery_lens`` holds, per sequence, the number of tokens
              still to be computed.
            - ``multi_modal_input`` is None when no group carries
              multi-modal data.

        Raises:
            RuntimeError: if chunked prefill is enabled and a group has a
                non-None ``computed_block_nums`` (prefix caching).
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (self.scheduler_config is not None
                    and self.scheduler_config.chunked_prefill_enabled
                    and computed_block_nums is not None):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            computed_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            # TODO(sang): Rename it after chunked prefill is introduced.
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
            prompt_len = len(prompt_tokens)
            # Right now, the prefill_end is always same as the length of
            # sequence. However, once chunked prefill is introduced, this
            # assumption can be changed.
            assert prefill_end == seq_data.get_len()
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert computed_len == 0

            # Number of tokens of this prompt that still have to be computed
            # (computed_len may have just been bumped by the prefix branch).
            subquery_len = prompt_len - computed_len
            # actual prompt lens
            context_lens.append(computed_len)
            subquery_lens.append(subquery_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(computed_len, prefill_end))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)
            lora_index_mapping += [lora_id] * subquery_len
            # With prompt logprobs every new prompt token is sampled;
            # otherwise only one row per prompt reaches the sampler.
            lora_prompt_mapping.extend(
                [lora_id] *
                (subquery_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prefill_end):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry k is the start offset of sequence k.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests, multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        """Build model inputs and attention metadata for a decode batch.

        Each sequence contributes exactly one token: its last token, placed
        at position ``len - 1``. CUDA graphs are used only when all three
        hold:
        - eager mode is off;
        - the batch is no larger than the largest captured batch size;
        - the longest context fits ``max_context_len_to_capture``.
        In that case the per-token lists are padded up to the graph batch
        size, and the block tables are written into the cached
        ``self.graph_block_tables`` array.

        Args:
            seq_group_metadata_list: Non-empty list of decode sequence groups.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        block_size = self.block_size
        window = self.sliding_window

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            seq_ids = list(group.seq_data.keys())
            lora_id = group.lora_int_id
            if lora_id > 0:
                lora_requests.add(group.lora_request)

            for seq_id in seq_ids:
                seq_data = group.seq_data[seq_id]
                input_tokens.append(seq_data.get_last_token_id())

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                if window is None:
                    context_lens.append(seq_len)
                else:
                    context_lens.append(min(seq_len, window))

                table = group.block_tables[seq_id]
                block_number = table[position // block_size]
                slot_mapping.append(block_number * block_size +
                                    position % block_size)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if window is not None:
                    # Keep only the trailing blocks covered by the window.
                    table = table[-(window // block_size):]
                block_tables.append(table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == input_tokens.
        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            num_pad = graph_batch_size - batch_size
            # Dummy entries: token 0 at position 0, no KV slot, context
            # length 1, empty block table, LoRA id 0.
            input_tokens.extend([0] * num_pad)
            input_positions.extend([0] * num_pad)
            slot_mapping.extend([_PAD_SLOT_ID] * num_pad)
            context_lens.extend([1] * num_pad)
            block_tables.extend([] for _ in range(num_pad))
            lora_index_mapping.extend([0] * num_pad)
            batch_size = graph_batch_size

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if not use_captured_graph:
            max_block_table_len = max(len(table) for table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )
        else:
            # When using cuda-graph all these tensors should be
            # padded.
            assert context_lens.shape[0] == input_tokens.shape[0]
            assert context_lens.shape[0] == input_positions.shape[0]
            assert context_lens.shape[0] == slot_mapping.shape[0]

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # NOTE: this slice is a view of the cached array, so rows are
            # overwritten in place and entries past len(table) are left as-is.
            cached_rows = self.graph_block_tables[:batch_size]
            for row, table in enumerate(block_tables):
                if table:
                    cached_rows[row, :len(table)] = table
            block_tables = torch.tensor(cached_rows, device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            prompt_lens_tensor=None,
            num_prompt_tokens=0,
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the SamplingMetadata for a batch.

        ``selected_token_indices`` lists the rows of the flattened model
        output that reach the sampler.
        - A prompt contributes only its last new token. With
          ``prompt_logprobs`` it contributes every new token.
        - A decode group contributes one row per sequence.

        ``categorized_sample_indices`` maps each sampling type to pairs.
        - The first element of a pair advances in step with
          ``selected_token_indices``.
        - The second element counts the tokens actually sampled.

        For seeded prompts, a fresh ``torch.Generator`` is stored on the
        group's state. The generators of all seeded groups are collected in
        batch order.

        Args:
            seq_group_metadata_list: Sequence groups in batch order.
            prompt_lens: Prompt lengths, forwarded unchanged.
            subquery_lens: New-token count per group; must be non-None when
                any group is a prompt.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        seq_data: Dict[int, SequenceData] = {}
        categorized_sample_indices = {t: [] for t in SamplingType}

        token_offset = 0   # next row of the flattened model output
        logit_offset = 0   # next position within selected_token_indices
        sample_offset = 0  # next position among sampled tokens

        for i, group in enumerate(seq_group_metadata_list):
            seq_ids = list(group.seq_data.keys())
            params = group.sampling_params
            seq_groups.append((seq_ids, params))

            if group.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                wants_prompt_logprobs = params.prompt_logprobs is not None
                if wants_prompt_logprobs:
                    # NOTE: prompt token positions do not need sample, skip
                    logit_offset += subquery_len - 1

                categorized_sample_indices[params.sampling_type].append(
                    [logit_offset, sample_offset])
                logit_offset += 1
                sample_offset += 1

                last_token_idx = token_offset + subquery_len - 1
                if wants_prompt_logprobs:
                    selected_token_indices.extend(
                        range(token_offset, last_token_idx))
                selected_token_indices.append(last_token_idx)
                token_offset += subquery_len

                if params.seed is not None:
                    group.state.generator = torch.Generator(
                        device=self.device).manual_seed(params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(token_offset, token_offset + num_seqs))
                token_offset += num_seqs

                categorized_sample_indices[params.sampling_type].extend(
                    zip(range(logit_offset, logit_offset + num_seqs),
                        range(sample_offset, sample_offset + num_seqs)))
                logit_offset += num_seqs
                sample_offset += num_seqs

            if params.seed is not None:
                generators.append(group.state.generator)

            seq_data.update(group.seq_data)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)

        categorized_sample_indices = {
            sampling_type: maybe_expand_dim(
                async_tensor_h2d(index_pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for sampling_type, index_pairs in
            categorized_sample_indices.items()
        }

        return SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping],
               Optional[torch.Tensor]]:
        """Prepare model inputs, broadcasting them from the driver worker.

        On the driver worker, this builds:
        - prompt or decode inputs (the batch is assumed to be all prompts or
          all decodes, judged from its first group);
        - the SamplingMetadata;
        - a LoRAMapping, when LoRA is enabled.
        It then broadcasts the tensors and metadata from rank 0.

        On the other workers, it receives that broadcast and rebuilds the
        attention metadata. It also creates a SamplingMetadata with
        ``perform_sampling=False``.

        Args:
            seq_group_metadata_list: Sequence groups for this step. Must be a
                non-empty list on the driver; not read on other workers.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. ``lora_requests`` is the set of
            LoRARequest objects in the batch. ``lora_mapping`` is None when
            LoRA is disabled. ``multi_modal_input`` is None without
            multi-modal data.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests, multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
                multi_modal_input = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            # Whatever remains are the attention-metadata fields.
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, where sampling is enabled, sample.

        Inputs come from ``prepare_input_tensors``. When LoRA is enabled, the
        active adapters are set first. Batches flagged ``use_cuda_graph`` run
        through the graph runner registered for their padded batch size;
        otherwise ``self.model`` is called directly.

        Returns:
            The sampler output, or None when the SamplingMetadata has
            ``perform_sampling`` disabled.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model.
        if not attn_metadata.use_cuda_graph:
            model_executable = self.model
        else:
            model_executable = self.graph_runners[input_tokens.shape[0]]

        execute_model_kwargs = dict(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        if self.vision_language_config:
            execute_model_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        return self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a worst-case dummy prefill batch to profile peak GPU memory.

        Builds up to ``scheduler_config.max_num_seqs`` dummy prompt sequence
        groups whose lengths sum to ``max_num_batched_tokens``. Fewer groups
        are used for vision-language models (see below). When LoRA is
        enabled, dummy LoRA adapters are attached. The batch is then run
        through ``execute_model`` with ``None`` KV caches, followed by a
        ``torch.cuda.synchronize()``.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Register one dummy LoRA adapter for each of the `max_loras` slots.
        # Assign them to the sequences round-robin. This is the maximum number
        # of distinct LoRAs in one batch, and therefore the worst case for LoRA
        # memory consumption.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        # Vision encoding may need additional GPU memory. That memory must be
        # accounted for when computing the number of GPU blocks for the vLLM
        # block manager. To exercise the worst case for GPU memory, the number
        # of sequences (batch size) is chosen to maximize the number of images
        # processed. Use exact integer division. Keep at least one sequence:
        # if a single image's features exceed max_num_batched_tokens, the
        # quotient is 0, and an empty batch would not measure anything.
        if self.vision_language_config:
            max_num_seqs = min(
                max_num_seqs,
                max(
                    1, max_num_batched_tokens //
                    self.vision_language_config.image_feature_size))

        for group_id in range(max_num_seqs):
            # Spread the tokens as evenly as possible. The first
            # `max_num_batched_tokens % max_num_seqs` groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
    def remove_all_loras(self) -> bool:
        """Remove every LoRA adapter held by ``self.lora_manager``.

        Returns:
            Whatever the LoRA manager's ``remove_all_loras`` returns.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy (LoRA disabled).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` using the token mapping ``lora_mapping``.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy (LoRA disabled).
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the LoRA manager.

        Returns:
            Whatever the LoRA manager's ``add_lora`` returns.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy (LoRA disabled).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with integer id ``lora_id``.

        Returns:
            Whatever the LoRA manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy (LoRA disabled).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the ids reported by the LoRA manager's ``list_loras``.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy (LoRA disabled).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

772
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph is captured for each size in ``_BATCH_SIZES_TO_CAPTURE``
        that does not exceed ``_get_graph_batch_size(max_num_seqs)``. The
        resulting runners are stored in ``self.graph_runners``, keyed by
        batch size, and each later capture reuses the previous graph's
        memory pool.

        Args:
            kv_caches: KV cache tensors passed to every graph capture.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # Allocate them directly on the current CUDA device instead of
        # building CPU tensors and copying them over with `.cuda()`.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   dtype=torch.long,
                                   device="cuda")
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long,
                                      device="cuda")
        # Every slot is padding, so the dummy run writes no real KV entries.
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long,
                                  device="cuda")
        context_lens = torch.ones(max_batch_size,
                                  dtype=torch.int32,
                                  device="cuda")
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    num_prompt_tokens=0,
                    num_generation_tokens=batch_size,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Share this graph's memory pool with the next capture.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

Woosuk Kwon's avatar
Woosuk Kwon committed
869
    def __del__(self) -> None:
        """Release the CUDA graphs, then drop the pynccl backend reference.

        The graph runners are cleared first so the graphs are deleted before
        the pynccl communicator.
        """
        # Delete the CUDA graphs before deleting the pynccl communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        # `__del__` also runs on objects whose `__init__` raised part-way.
        # `graph_runners` may then be unset, so look it up defensively rather
        # than raising inside the finalizer.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
878

879
880
881
882
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        config = self.model_config
        size = config.get_vocab_size()
        return size

883
884
885
886
887
888
889
890
891
892
893
894
895

class CUDAGraphRunner:
    """Record one forward pass of ``model`` as a CUDA graph and replay it.

    ``capture`` first runs the model once without recording, then records
    it into a ``torch.cuda.CUDAGraph``. The tensors used during capture are
    kept as static input buffers. ``forward`` copies new values into those
    buffers, replays the graph, and returns the captured output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # The recorded graph; stays None until `capture` has run.
        self.graph = None
        # Static tensors read by the graph, keyed by argument name.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # Tensors produced by the captured forward pass.
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Warm up the model, then record it into ``self.graph``.

        Args:
            input_ids: Token ids; kept as the static ``input_ids`` buffer.
            positions: Positions; kept as the static ``positions`` buffer.
            kv_caches: KV cache tensors; stored but not re-copied on replay.
            attn_metadata: Attention metadata whose ``slot_mapping``,
                ``context_lens`` and ``block_tables`` become static buffers.
            memory_pool: Memory pool handle passed to ``torch.cuda.graph``.
            **kwargs: Extra model arguments. They are only used during
                capture and are not refreshed by ``forward``.
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                    **kwargs,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": attn_metadata.context_lens,
            "block_tables": attn_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the static buffers and replay the graph.

        Returns:
            The static ``hidden_states`` tensor saved by ``capture``. The same
            tensor object is returned on every call.

        Raises:
            RuntimeError: If called before ``capture``.
        """
        # Without this check, a missing capture surfaces as an opaque
        # KeyError on the empty input buffer dict.
        if self.graph is None:
            raise RuntimeError(
                "CUDAGraphRunner.forward() was called before capture().")

        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches
        # NOTE(review): `kwargs` are not copied into any buffer, so the replay
        # uses the values that were passed to `capture`.

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

971

972
@contextlib.contextmanager
def _maybe_pynccl():
    """Enter ``with_pynccl_for_all_reduce()`` only when pynccl should be used.

    That is the case when ``pynccl_utils`` is initialized and
    ``custom_all_reduce`` is not. Otherwise this context manager does
    nothing around the wrapped block.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    context = (with_pynccl_for_all_reduce()
               if use_pynccl else contextlib.nullcontext())
    with context:
        yield


982
def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Sizes up to 2 are returned unchanged, 3 and 4 are padded to 4, and
    anything larger is rounded up to the next multiple of
    ``_BATCH_SIZE_ALIGNMENT``. The resulting sizes are therefore
    1, 2, 4, _BATCH_SIZE_ALIGNMENT, 2*_BATCH_SIZE_ALIGNMENT, ...
    """
    if batch_size > 4:
        # Ceiling division expressed as negated floor division.
        num_chunks = -(-batch_size // _BATCH_SIZE_ALIGNMENT)
        return num_chunks * _BATCH_SIZE_ALIGNMENT
    if batch_size > 2:
        return 4
    return batch_size
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run.

    Returns a ``(SequenceData, fake_image_input)`` pair. Without a
    vision-language config, the prompt is ``seq_len`` zeros and the image
    input is ``None``. With one, the prompt is ``image_feature_size`` copies
    of ``image_token_id`` followed by zeros up to ``seq_len``. The image
    input is then a float16 all-zeros tensor of ``image_input_shape``.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    # Repeating a list a non-positive number of times yields [], so no zero
    # padding is added when seq_len <= num_image_tokens.
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    padding = [0] * (seq_len - num_image_tokens)
    prompt_tokens = image_tokens + padding
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(prompt_tokens), fake_image_input