model_runner.py 44.4 KB
Newer Older
1
import contextlib
2
import time
3
from typing import Dict, List, Optional, Set, Tuple
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

9
from vllm.attention import AttentionMetadata, get_attn_backend
10
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
11
                         SchedulerConfig, VisionLanguageConfig)
12
from vllm.logger import init_logger
13
14
15
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
16
from vllm.model_executor import SamplingMetadata
17
from vllm.model_executor.model_loader import get_model
18
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
19
from vllm.model_executor.parallel_utils.communication_op import (
20
    broadcast_tensor_dict)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.model_executor.parallel_utils.parallel_state import (
22
    with_pynccl_for_all_reduce)
23
from vllm.sampling_params import SamplingParams, SamplingType
24
25
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
26
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
27
28
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
29
30
31
32

logger = init_logger(__name__)

# Placeholder slot id used for padded or masked entries of a slot mapping.
_PAD_SLOT_ID = -1
# NOTE(review): LoRA warm-up rank; its consumer is not visible in this chunk.
LORA_WARMUP_RANK = 8
# Captured batch sizes from 8 upwards are multiples of this value.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))
40
41
42
43
44
45
46
47
48


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the configuration and initialise empty runtime state.

        No model is loaded here: ``self.model``, ``self.block_size``,
        ``self.lora_manager``, ``self.graph_memory_pool`` and
        ``self.graph_block_tables`` all start as ``None``.

        Args:
            model_config: Model configuration. May be ``None`` (see the
                FIXME below); then the sliding window is ``None``, the
                maximum context length to capture is 0 and the attention
                backend is selected with a ``None`` dtype.
            parallel_config: Stored as-is.
            scheduler_config: Stored as-is.
            device_config: Device configuration; a default ``DeviceConfig()``
                is created when ``None`` is passed.
            lora_config: LoRA configuration, or ``None``.
            kv_cache_dtype: KV-cache data type string.
            is_driver_worker: Whether this runner belongs to the driver
                worker.
            vision_language_config: Vision-language configuration, or
                ``None``.
        """
        has_model_config = model_config is not None

        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        if has_model_config:
            self.sliding_window = model_config.get_sliding_window()
        else:
            self.sliding_window = None

        if device_config is None:
            device_config = DeviceConfig()
        self.device_config = device_config
        self.device = device_config.device

        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        if has_model_config:
            self.max_context_len_to_capture = (
                model_config.max_context_len_to_capture)
        else:
            self.max_context_len_to_capture = 0
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        model_dtype = model_config.dtype if has_model_config else None
        self.attn_backend = get_attn_backend(model_dtype)

93
    def load_model(self) -> None:
        """Load the model and set up optional LoRA and FP8 KV-cache support.

        The object returned by ``get_model`` is stored in ``self.model`` and
        the memory consumed while loading, as reported by
        ``CudaMemoryProfiler``, in ``self.model_memory_usage``. When a LoRA
        config is present, a ``LRUCacheWorkerLoRAManager`` is created and
        ``self.model`` is replaced by what its ``create_lora_manager``
        returns. KV-cache scaling factors are loaded only when the KV-cache
        dtype is ``"fp8"``, ``is_hip()`` is true and a
        ``quantization_param_path`` is configured.

        Raises:
            AssertionError: If LoRA is configured but the model lacks a
                non-empty ``supported_lora_modules``, or lacks
                ``embedding_modules`` / ``embedding_padding_modules``.
            RuntimeError: If scaling factors are provided for an FP8 KV
                cache but the model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                self.model_config,
                self.device_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)

        self.model_memory_usage = m.consumed_memory
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    raise RuntimeError("Using FP8 KV cache and scaling "
                                       "factors provided but model "
                                       f"{self.model.__class__} does not "
                                       "support loading scaling factors.")
            else:
                # `Logger.warn` is a deprecated alias; use `warning`.
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")

143
144
145
    def set_block_size(self, block_size: int) -> None:
        """Record the block size and allocate the cached graph block table.

        The cache is a zero-filled ``int32`` numpy array with one row per
        sequence of the largest captured batch size and one column per block
        returned by ``get_max_block_per_batch``.

        Args:
            block_size: Block size to store in ``self.block_size``.
        """
        self.block_size = block_size

        num_rows = max(_BATCH_SIZES_TO_CAPTURE)
        num_cols = self.get_max_block_per_batch()
        self.graph_block_tables = np.zeros((num_rows, num_cols),
                                           dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_context_len_to_capture + block_size - 1) // block_size
153

154
155
156
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest],
               Optional[torch.Tensor]]:
        """Build the flattened model inputs for a batch of prompt requests.

        Args:
            seq_group_metadata_list: Non-empty list of sequence groups that
                are all in the prompt stage; each must hold exactly one
                sequence.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            prompt_lens, subquery_lens, lora_index_mapping,
            lora_prompt_mapping, lora_requests, multi_modal_input)``.
            ``input_tokens`` and ``input_positions`` are 1-D ``long`` tensors
            on ``self.device``. ``multi_modal_input`` is ``None`` when no
            sequence group carries multi-modal data.

        Raises:
            RuntimeError: If chunked prefill is enabled and a sequence group
                has ``computed_block_nums`` set.
            AssertionError: If one of the batch invariants checked below is
                violated.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (self.scheduler_config is not None
                    and self.scheduler_config.chunked_prefill_enabled
                    and computed_block_nums is not None):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            computed_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            # TODO(sang): Rename it after chunked prefill is introduced.
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
            prompt_len = len(prompt_tokens)
            # Right now, the prefill_end is always same as the length of
            # sequence. However, once chunked prefill is introduced, this
            # assumption can be changed.
            assert prefill_end == seq_data.get_len()
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert computed_len == 0

            # Number of tokens that are actually fed through the model.
            subquery_len = prompt_len - computed_len

            # actual prompt lens
            context_lens.append(computed_len)
            subquery_lens.append(subquery_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(computed_len, prefill_end))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * subquery_len
            # `_prepare_sample` selects every subquery token whenever
            # prompt_logprobs is not None, so the same test is used here.
            # A truthiness check would disagree for prompt_logprobs == 0.
            wants_prompt_logprobs = (
                seq_group_metadata.sampling_params.prompt_logprobs
                is not None)
            lora_prompt_mapping.extend(
                [lora_id] * (subquery_len if wants_prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prefill_end):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Element 0 of each *_start_loc stays 0; the cumulative sums are
        # written into the remaining elements.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping_tensor,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens_tensor, input_positions_tensor, attn_metadata,
                prompt_lens, subquery_lens, lora_index_mapping,
                lora_prompt_mapping, lora_requests, multi_modal_input)
355
356
357
358

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        """Build the model inputs for a batch of decode requests.

        Each sequence contributes exactly one token (its last token id). When
        a captured CUDA graph can be used, the per-token inputs are padded up
        to the batch size returned by ``_get_graph_batch_size``.

        Args:
            seq_group_metadata_list: Non-empty list of sequence groups that
                are all in the decode stage with ``token_chunk_size == 1``.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)``.
        """
        assert len(seq_group_metadata_list) > 0

        token_ids: List[int] = []
        positions: List[int] = []
        slots: List[int] = []
        ctx_lens: List[int] = []
        seq_block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            seq_ids = list(group.seq_data.keys())
            lora_id = group.lora_int_id
            if lora_id > 0:
                lora_requests.add(group.lora_request)

            for seq_id in seq_ids:
                seq_data = group.seq_data[seq_id]
                token_ids.append(seq_data.get_last_token_id())

                seq_len = seq_data.get_len()
                position = seq_len - 1
                positions.append(position)

                if self.sliding_window is None:
                    ctx_lens.append(seq_len)
                else:
                    ctx_lens.append(min(seq_len, self.sliding_window))

                # Locate the cache slot of the token at `position`.
                block_table = group.block_tables[seq_id]
                block_idx, block_offset = divmod(position, self.block_size)
                slots.append(block_table[block_idx] * self.block_size +
                             block_offset)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks covered by the window.
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                seq_block_tables.append(block_table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == input_tokens.
        batch_size = len(token_ids)
        max_context_len = max(ctx_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            # Pad every per-token list up to the captured batch size.
            num_pad = graph_batch_size - batch_size
            token_ids.extend([0] * num_pad)
            positions.extend([0] * num_pad)
            slots.extend([_PAD_SLOT_ID] * num_pad)
            ctx_lens.extend([1] * num_pad)
            seq_block_tables.extend([] for _ in range(num_pad))
            lora_index_mapping.extend([0] * num_pad)
            batch_size = graph_batch_size

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(ctx_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be
            # padded.
            padded_len = context_lens.shape[0]
            assert padded_len == input_tokens.shape[0]
            assert padded_len == input_positions.shape[0]
            assert padded_len == slot_mapping.shape[0]

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for row, table in enumerate(seq_block_tables):
                if not table:
                    # Padding entries have no blocks to copy.
                    continue
                input_block_tables[row, :len(table)] = table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(table) for table in seq_block_tables)
            block_tables = make_tensor_with_pad(
                seq_block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            prompt_lens_tensor=None,
            num_prompt_tokens=0,
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)
486
487
488
489
490

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Assemble the ``SamplingMetadata`` for the current batch.

        Args:
            seq_group_metadata_list: Sequence groups of the batch.
            prompt_lens: Passed through to ``SamplingMetadata``.
            subquery_lens: Per-group subquery lengths; must not be ``None``
                when the batch contains prompt-stage groups.

        Returns:
            The populated ``SamplingMetadata``. As a side effect, a seeded
            ``torch.Generator`` is stored on ``state.generator`` of every
            prompt-stage group whose sampling params carry a seed.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        token_indices: List[int] = []
        generators: List[torch.Generator] = []
        # Offset of the current group within the flattened input tokens.
        token_cursor = 0
        indices_by_type = {t: [] for t in SamplingType}
        # Each entry stored in `indices_by_type` is a pair of running
        # counters: (index among selected tokens, index among sampled tokens).
        selected_cursor = 0
        sampled_cursor = 0

        for group_idx, group in enumerate(seq_group_metadata_list):
            seq_ids = list(group.seq_data.keys())
            sampling_params = group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            has_seed = sampling_params.seed is not None

            if group.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[group_idx]
                with_prompt_logprobs = (sampling_params.prompt_logprobs
                                        is not None)

                if with_prompt_logprobs:
                    # NOTE: prompt token positions do not need sample, skip
                    selected_cursor += subquery_len - 1

                indices_by_type[sampling_params.sampling_type].append(
                    [selected_cursor, sampled_cursor])
                selected_cursor += 1
                sampled_cursor += 1

                last_token_idx = token_cursor + subquery_len - 1
                if with_prompt_logprobs:
                    token_indices.extend(range(token_cursor, last_token_idx))
                token_indices.append(last_token_idx)
                token_cursor += subquery_len

                if has_seed:
                    generator = torch.Generator(device=self.device)
                    group.state.generator = generator.manual_seed(
                        sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                token_indices.extend(
                    range(token_cursor, token_cursor + num_seqs))
                token_cursor += num_seqs

                indices_by_type[sampling_params.sampling_type].extend(
                    zip(range(selected_cursor, selected_cursor + num_seqs),
                        range(sampled_cursor, sampled_cursor + num_seqs)))
                selected_cursor += num_seqs
                sampled_cursor += num_seqs

            if has_seed:
                generators.append(group.state.generator)

        selected_token_indices = async_tensor_h2d(token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)

        categorized_sample_indices = {}
        for sampling_type, pairs in indices_by_type.items():
            pairs_tensor = async_tensor_h2d(pairs,
                                            dtype=torch.int,
                                            target_device=self.device,
                                            pin_memory=self.pin_memory)
            categorized_sample_indices[sampling_type] = maybe_expand_dim(
                pairs_tensor, 2, 2)

        # Merge the per-group sequence data into one mapping.
        seq_data: Dict[int, SequenceData] = {}
        for group in seq_group_metadata_list:
            seq_data.update(group.seq_data)

        return SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )

585
586
587
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping],
               Optional[torch.Tensor]]:
        """Prepare the model inputs and share them with the other workers.

        On the driver worker the inputs are built from
        ``seq_group_metadata_list`` and sent with ``broadcast_tensor_dict``.
        On every other worker the broadcast dictionary is received and the
        attention metadata is rebuilt from it; the resulting
        ``SamplingMetadata`` only carries ``selected_token_indices`` and has
        ``perform_sampling=False``.

        Args:
            seq_group_metadata_list: Sequence groups of the batch. Read only
                on the driver worker, where its first element is indexed.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. On the driver worker ``lora_mapping`` is
            ``None`` when no LoRA config is set and ``multi_modal_input`` is
            ``None`` for decode batches and for prompt batches without
            multi-modal data.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests, multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                # Decode batches have no prompt-side information.
                prompt_lens = []
                subquery_lens = None
                multi_modal_input = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            # Pop the non-attention entries; whatever is left is passed as
            # keyword arguments to the attention backend's make_metadata.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
654

655
656
657
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, if this runner samples, pick next tokens.

        Args:
            seq_group_metadata_list: Sequence groups for this step; passed
                unchanged to ``prepare_input_tensors``.
            kv_caches: KV cache tensors forwarded to the model as
                ``kv_caches``.

        Returns:
            Whatever ``self.model.sample`` returns, or ``None`` when
            ``sampling_metadata.perform_sampling`` is false.
        """
        prepared = self.prepare_input_tensors(seq_group_metadata_list)
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input) = prepared

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model. Captured graph runners are keyed by the batch
        # size, which is read from the leading dimension of the input tokens.
        model_executable = (self.graph_runners[input_tokens.shape[0]]
                            if attn_metadata.use_cuda_graph else self.model)

        model_inputs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            # The image input is only passed when a vision-language config
            # is set.
            model_inputs["image_input"] = multi_modal_input
        hidden_states = model_executable(**model_inputs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if sampling_metadata.perform_sampling:
            # Sample the next token.
            return self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        return None

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run the model once on dummy prompts sized to the scheduler limits.

        Builds ``max_num_seqs`` fake prompt sequence groups whose lengths sum
        to ``max_num_batched_tokens`` (registering dummy LoRAs first when LoRA
        is configured), executes the model on them with no KV caches, and
        waits for the GPU to finish.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        scheduler_config = self.scheduler_config
        max_num_batched_tokens = scheduler_config.max_num_batched_tokens
        max_num_seqs = scheduler_config.max_num_seqs

        # max_loras is the maximum number of different requests that will
        # have unique LoRAs, and therefore the maximum memory consumption.
        # Register that many dummy LoRAs and hand them out to the fake
        # sequences round-robin.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for lora_id in range(1, self.lora_config.max_loras + 1):
                warmup_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(warmup_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(warmup_request)
            num_dummy_loras = len(dummy_lora_requests)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[seq_idx % num_dummy_loras]
                for seq_idx in range(max_num_seqs)
            ]

        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            image_feature_size = self.vision_language_config.image_feature_size
            seqs_that_fit = int(max_num_batched_tokens / image_feature_size)
            max_num_seqs = min(max_num_seqs, seqs_that_fit)

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first `leftover` groups
            # each take one extra token. Computed inside the loop so that an
            # empty range never divides by zero.
            tokens_per_seq, leftover = divmod(max_num_batched_tokens,
                                              max_num_seqs)
            seq_len = tokens_per_seq + (group_id < leftover)
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            if dummy_lora_requests_per_seq:
                lora_request = dummy_lora_requests_per_seq[group_id]
            else:
                lora_request = None
            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=lora_request,
                    multi_modal_data=fake_multi_modal_input,
                ))

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
    def remove_all_loras(self) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_all_loras()

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the active LoRA requests and mapping to the LoRA manager.

        Args:
            lora_requests: LoRA requests passed through unchanged.
            lora_mapping: LoRA mapping passed through unchanged.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy.
        """
        manager = self.lora_manager
        if manager:
            manager.set_active_loras(lora_requests, lora_mapping)
            return
        raise RuntimeError("LoRA is not enabled.")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the LoRA manager's ``add_lora``.

        Args:
            lora_request: The LoRA request passed through unchanged.

        Returns:
            The value returned by the LoRA manager.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy.
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

791
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Capture CUDA graphs of the model for a set of decode batch sizes.

        The performance gain from CUDA graphs is negligible once the number
        of batched tokens exceeds roughly 200, and because CUDA graphs need
        fixed-size tensors, supporting large or variable batch sizes costs a
        lot of GPU memory. vLLM therefore captures decoding requests only;
        prefill and mixed (chunked prefill + decoding) batches are not
        captured.

        Being decode-only, the capture assumes exactly one token per sequence
        in the batch.

        Args:
            kv_caches: KV cache tensors passed to every graph capture.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. They are sized for the largest batch size and
        # sliced down for each smaller one.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Only capture batch sizes up to the padded max_num_seqs.
        largest_needed = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            candidate for candidate in _BATCH_SIZES_TO_CAPTURE
            if candidate <= largest_needed
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for bs in batch_size_capture_list[::-1]:
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:bs],
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    num_prompt_tokens=0,
                    num_generation_tokens=bs,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:bs],
                    block_tables=block_tables[:bs],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    # All-zero index and prompt mappings with no active
                    # LoRA requests.
                    lora_mapping = LoRAMapping(
                        [0] * bs,
                        [0] * bs,
                    )
                    self.set_active_loras(set(), lora_mapping)

                runner = CUDAGraphRunner(self.model)
                runner.capture(
                    input_tokens[:bs],
                    input_positions[:bs],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Hand this graph's pool to the next capture so that all
                # graphs are captured against the same memory pool.
                self.graph_memory_pool = runner.graph.pool()
                self.graph_runners[bs] = runner

        elapsed_time = time.perf_counter() - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

Woosuk Kwon's avatar
Woosuk Kwon committed
888
    def __del__(self) -> None:
889
        # Delete the CUDA graphs before deleting the pynccl communicator.
Woosuk Kwon's avatar
Woosuk Kwon committed
890
891
892
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
893
894
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
Woosuk Kwon's avatar
Woosuk Kwon committed
895
        self.graph_runners.clear()
896
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
897

898
899
900
901
    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

902
903
904
905
906
907
908
909
910
911
912
913
914

class CUDAGraphRunner:
    """Captures one CUDA graph of a model call and replays it on new inputs.

    ``capture`` records the graph and keeps references to the tensors it was
    recorded with; ``forward`` copies fresh values into those tensors, replays
    the graph and returns the recorded output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Set by capture(); None until then.
        self.graph = None
        # Tensors the graph reads from / writes to, keyed by name.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Record the model call as a CUDA graph.

        Args:
            input_ids: Token ids passed positionally to the model.
            positions: Positions passed positionally to the model.
            kv_caches: KV caches passed positionally to the model.
            attn_metadata: Attention metadata passed positionally to the
                model; its ``slot_mapping``, ``context_lens`` and
                ``block_tables`` are kept as input buffers.
            memory_pool: Passed as ``pool`` to ``torch.cuda.graph``.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self.graph is None
        model_args = (input_ids, positions, kv_caches, attn_metadata)

        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = dict(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            slot_mapping=attn_metadata.slot_mapping,
            context_lens=attn_metadata.context_lens,
            block_tables=attn_metadata.block_tables,
        )
        self.output_buffers = dict(hidden_states=hidden_states)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy the inputs into the captured buffers and replay the graph.

        Args:
            input_ids: New token ids, copied into the captured buffer.
            positions: New positions, copied into the captured buffer.
            kv_caches: Not used here; see the comment below.
            attn_metadata: Source of the new ``slot_mapping``,
                ``context_lens`` and ``block_tables`` values.
            **kwargs: Accepted but not used by this method.

        Returns:
            The ``hidden_states`` tensor saved at capture time.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        buffers = self.input_buffers
        buffers["input_ids"].copy_(input_ids, non_blocking=True)
        buffers["positions"].copy_(positions, non_blocking=True)
        for field in ("slot_mapping", "context_lens", "block_tables"):
            buffers[field].copy_(getattr(attn_metadata, field),
                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        """Make the runner callable like the model; same as ``forward``."""
        return self.forward(*args, **kwargs)

990

991
@contextlib.contextmanager
def _maybe_pynccl():
    """Run the ``with`` body under pynccl all-reduce when that applies.

    The body runs inside ``with_pynccl_for_all_reduce()`` only if pynccl is
    initialized and the custom all-reduce kernel is not; otherwise it runs
    with no extra context.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    if not use_pynccl:
        yield
        return
    with with_pynccl_for_all_reduce():
        yield


1001
def _get_graph_batch_size(batch_size: int) -> int:
    """Return the padded batch size for an actual batch size.

    Padded sizes are 1, 2, 4, then multiples of _BATCH_SIZE_ALIGNMENT
    (_BATCH_SIZE_ALIGNMENT, 2*_BATCH_SIZE_ALIGNMENT, ...). Inputs of 2 or
    less are returned unchanged.
    """
    if batch_size > 4:
        # Round up to the next multiple of the alignment.
        num_aligned_chunks = ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                              _BATCH_SIZE_ALIGNMENT)
        return num_aligned_chunks * _BATCH_SIZE_ALIGNMENT
    return batch_size if batch_size <= 2 else 4
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build dummy sequence data, plus a dummy image if configured.

    Args:
        seq_len: Number of tokens in the text-only prompt. With a
            vision-language config, the prompt is ``image_feature_size``
            image tokens followed by ``seq_len - image_feature_size`` zeros.
        vision_language_config: When truthy, supplies the image token id,
            feature size and input shape for the fake image.

    Returns:
        A ``(SequenceData, MultiModalData or None)`` pair.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    feature_size = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * feature_size
    text_padding = [0] * (seq_len - feature_size)
    prompt_tokens = image_tokens + text_padding
    # All-zero float16 tensor with the configured image input shape.
    zero_image = torch.zeros(vision_language_config.image_input_shape,
                             dtype=torch.float16)
    fake_image_input = MultiModalData(type=MultiModalData.Type.IMAGE,
                                      data=zero_image)
    return SequenceData(prompt_tokens), fake_image_input