model_runner.py 44.4 KB
Newer Older
1
import contextlib
2
import time
3
from typing import Dict, List, Optional, Set, Tuple
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

9
from vllm.attention import AttentionMetadata, get_attn_backend
10
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
11
                         SchedulerConfig, VisionLanguageConfig)
12
from vllm.logger import init_logger
13
14
15
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
16
from vllm.model_executor import SamplingMetadata
17
from vllm.model_executor.model_loader import get_model
18
from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
19
from vllm.model_executor.parallel_utils.communication_op import (
20
    broadcast_tensor_dict)
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.model_executor.parallel_utils.parallel_state import (
22
    with_pynccl_for_all_reduce)
23
from vllm.sampling_params import SamplingParams, SamplingType
24
25
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
26
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
27
28
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
29
30
31
32

logger = init_logger(__name__)

# Slot id written for tokens that get no KV-cache slot (memory profiling,
# sliding-window masking, CUDA-graph batch padding).
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# CUDA graphs are captured for batch sizes 1, 2 and 4, then for every multiple
# of _BATCH_SIZE_ALIGNMENT from 8 up to 256 (8, 16, 24, ..., 256).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, 33 * _BATCH_SIZE_ALIGNMENT,
          _BATCH_SIZE_ALIGNMENT))


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the engine configuration used by this model runner.

        No model is built here (``load_model`` does that), and
        ``block_size`` / ``graph_block_tables`` stay ``None`` until
        ``set_block_size`` is called. ``model_config`` may be ``None`` (some
        sampler tests construct the runner without one); in that case the
        sliding window is ``None``, ``max_context_len_to_capture`` is 0 and
        ``get_attn_backend`` receives ``None`` instead of a dtype. A missing
        ``device_config`` is replaced by a default ``DeviceConfig()``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        has_model_config = model_config is not None
        self.sliding_window = (model_config.get_sliding_window()
                               if has_model_config else None)

        if device_config is None:
            device_config = DeviceConfig()
        self.device_config = device_config
        self.device = device_config.device

        # Populated later by load_model().
        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            model_config.max_context_len_to_capture
            if has_model_config else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        self.attn_backend = get_attn_backend(
            model_config.dtype if has_model_config else None)
    def load_model(self) -> None:
        """Load the model weights and apply optional LoRA / KV-cache scaling.

        ``self.model`` is built with ``get_model`` while ``CudaMemoryProfiler``
        measures the memory consumed, which is stored in
        ``self.model_memory_usage``. When ``self.lora_config`` is set, the
        model is checked for LoRA support and wrapped through an
        ``LRUCacheWorkerLoRAManager``. When the KV cache dtype is ``"fp8"`` on
        ROCm, KV-cache scaling factors are loaded from
        ``model_config.quantization_param_path`` if one is provided.

        Raises:
            AssertionError: if LoRA is requested but the model lacks
                ``supported_lora_modules``, ``embedding_modules`` or
                ``embedding_padding_modules``.
            RuntimeError: if FP8 KV-cache scaling factors are provided but
                the model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                self.model_config,
                self.device_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)

        self.model_memory_usage = m.consumed_memory
        # Lazy %-style arguments: the message is only formatted if emitted.
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    raise RuntimeError("Using FP8 KV cache and scaling "
                                       "factors provided but model "
                                       f"{self.model.__class__} does not "
                                       "support loading scaling factors.")
            else:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning("Using FP8 KV cache but no scaling factors "
                               "provided. Defaulting to scaling factors of "
                               "1.0. This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")
    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the graph block table.

        The cached table is a zero-filled ``np.int32`` array with one row per
        batch slot up to the largest entry of ``_BATCH_SIZES_TO_CAPTURE`` and
        one column per block needed to cover ``max_context_len_to_capture``.
        """
        self.block_size = block_size
        num_rows = max(_BATCH_SIZES_TO_CAPTURE)
        num_cols = self.get_max_block_per_batch()
        self.graph_block_tables = np.zeros((num_rows, num_cols),
                                           dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        """Return how many blocks are needed to hold the captured context.

        This is ``max_context_len_to_capture`` divided by ``block_size``,
        rounded up; ``block_size`` must already be set.
        """
        size = self.block_size
        padded_len = self.max_context_len_to_capture + size - 1
        return padded_len // size
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest],
               torch.Tensor]:
        """Build model inputs and attention metadata for a batch of prefills.

        Every group in ``seq_group_metadata_list`` must be a prompt holding
        exactly one sequence. When prefix-cache blocks are reported in
        ``computed_block_nums`` (and no sliding window is configured), the
        tokens they cover are skipped so only the remaining "subquery" tokens
        are fed to the model.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, prompt_lens,
            subquery_lens, lora_index_mapping, lora_prompt_mapping,
            lora_requests, multi_modal_input)``. ``input_tokens`` and
            ``input_positions`` are 1-D ``torch.long`` tensors on
            ``self.device``; ``lora_index_mapping`` has one entry per input
            token; ``lora_prompt_mapping`` has one entry per row that
            ``_prepare_sample`` selects for each group; ``multi_modal_input``
            is ``None`` unless some group carries multi-modal data.

        Raises:
            RuntimeError: if chunked prefill is enabled together with
                prefix caching.
        """
        assert len(seq_group_metadata_list) > 0

        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (self.scheduler_config is not None
                    and self.scheduler_config.chunked_prefill_enabled
                    and computed_block_nums is not None):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            computed_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            # TODO(sang): Rename it after chunked prefill is introduced.
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
            prompt_len = len(prompt_tokens)
            # Right now, the prefill_end is always same as the length of
            # sequence. However, once chunked prefill is introduced, this
            # assumption can be changed.
            assert prefill_end == seq_data.get_len()
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert computed_len == 0

            # Number of tokens actually fed to the model for this sequence.
            subquery_len = prompt_len - computed_len
            # actual prompt lens
            context_lens.append(computed_len)
            subquery_lens.append(subquery_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(computed_len, prefill_end))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * subquery_len
            # _prepare_sample selects every subquery token (not only the last
            # one) whenever prompt_logprobs is not None -- including
            # prompt_logprobs == 0 -- so the same test must be used here to
            # keep the LoRA prompt mapping row-aligned with the sampler.
            wants_prompt_logprobs = (
                seq_group_metadata.sampling_params.prompt_logprobs
                is not None)
            lora_prompt_mapping.extend(
                [lora_id] * (subquery_len if wants_prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prefill_end):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry k is the offset of sequence k in the
        # flattened batch; entry 0 stays 0.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests, multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        """Build model inputs and attention metadata for one decode step.

        Every group must be a non-prompt group with ``token_chunk_size == 1``;
        each of its sequences contributes exactly its last token. When CUDA
        graphs can be used (``enforce_eager`` is off, the batch is no larger
        than the largest captured size, and every context fits in
        ``max_context_len_to_capture``), the batch is padded up to
        ``_get_graph_batch_size(batch_size)`` with dummy entries whose slot is
        ``_PAD_SLOT_ID``, and block tables are copied into the preallocated
        ``self.graph_block_tables`` buffer.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)``.
            ``lora_index_mapping`` is padded together with the inputs;
            ``lora_prompt_mapping`` keeps one entry per real sequence.
        """
        assert len(seq_group_metadata_list) > 0

        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        block_size = self.block_size
        window = self.sliding_window

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            lora_id = group.lora_int_id
            if lora_id > 0:
                lora_requests.add(group.lora_request)

            for seq_id, seq_data in group.seq_data.items():
                input_tokens.append(seq_data.get_last_token_id())

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                context_lens.append(
                    seq_len if window is None else min(seq_len, window))

                table = group.block_tables[seq_id]
                physical_block = table[position // block_size]
                slot_mapping.append(physical_block * block_size +
                                    position % block_size)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if window is not None:
                    # Keep only the blocks that can fall inside the window.
                    table = table[-(window // block_size):]
                block_tables.append(table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == input_tokens.
        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            num_pad = graph_batch_size - batch_size
            input_tokens.extend([0] * num_pad)
            input_positions.extend([0] * num_pad)
            slot_mapping.extend([_PAD_SLOT_ID] * num_pad)
            context_lens.extend([1] * num_pad)
            block_tables.extend([] for _ in range(num_pad))
            lora_index_mapping.extend([0] * num_pad)
            batch_size = graph_batch_size

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if not use_captured_graph:
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max(map(len, block_tables)),
                pad=0,
                dtype=torch.int,
                device=self.device,
            )
        else:
            # When using cuda-graph all these tensors should be
            # padded.
            assert context_lens.shape[0] == input_tokens.shape[0]
            assert context_lens.shape[0] == input_positions.shape[0]
            assert context_lens.shape[0] == slot_mapping.shape[0]

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            graph_tables = self.graph_block_tables[:batch_size]
            for row, table in enumerate(block_tables):
                if table:
                    graph_tables[row, :len(table)] = table
            block_tables = torch.tensor(graph_tables, device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            prompt_lens_tensor=None,
            num_prompt_tokens=0,
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the ``SamplingMetadata`` describing which rows to sample.

        For a prompt group only the last subquery position is selected,
        unless ``prompt_logprobs`` is not None, in which case every subquery
        position is selected. A decode group selects one row per sequence.
        ``categorized_sample_indices`` maps each ``SamplingType`` to pairs of
        (index into the selected rows, index into the sampled tokens), moved
        to ``self.device`` and passed through ``maybe_expand_dim``. A seeded
        prompt group gets a fresh ``torch.Generator`` stored on its ``state``;
        every seeded group contributes that generator to ``generators``.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        categorized_sample_indices = {t: [] for t in SamplingType}

        next_hidden_row = 0  # offset into the model's flattened outputs
        next_logit_row = 0  # offset into the selected (logit) rows
        next_sampled_row = 0  # offset into the sampled tokens

        for i, group in enumerate(seq_group_metadata_list):
            seq_ids = list(group.seq_data.keys())
            params = group.sampling_params
            seq_groups.append((seq_ids, params))

            if not group.is_prompt:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(next_hidden_row, next_hidden_row + num_seqs))
                next_hidden_row += num_seqs

                categorized_sample_indices[params.sampling_type].extend(
                    zip(range(next_logit_row, next_logit_row + num_seqs),
                        range(next_sampled_row,
                              next_sampled_row + num_seqs)))
                next_logit_row += num_seqs
                next_sampled_row += num_seqs
            else:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                keeps_prompt_logprobs = params.prompt_logprobs is not None

                if keeps_prompt_logprobs:
                    # NOTE: prompt token positions do not need sample, skip
                    next_logit_row += subquery_len - 1

                categorized_sample_indices[params.sampling_type].append(
                    [next_logit_row, next_sampled_row])
                next_logit_row += 1
                next_sampled_row += 1

                last_row = next_hidden_row + subquery_len - 1
                if keeps_prompt_logprobs:
                    selected_token_indices.extend(
                        range(next_hidden_row, last_row))
                selected_token_indices.append(last_row)
                next_hidden_row += subquery_len

                if params.seed is not None:
                    group.state.generator = torch.Generator(
                        device=self.device).manual_seed(params.seed)

            if params.seed is not None:
                generators.append(group.state.generator)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)

        categorized_sample_indices = {
            sampling_type: maybe_expand_dim(
                async_tensor_h2d(index_pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for sampling_type, index_pairs in
            categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for group in seq_group_metadata_list:
            seq_data.update(group.seq_data)

        return SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping],
               Optional[torch.Tensor]]:
        """Prepare model inputs, broadcasting them from the driver worker.

        On the driver worker (``self.is_driver_worker``), inputs are built
        from ``seq_group_metadata_list`` (all prompts or all decodes) and
        broadcast from rank 0 via ``broadcast_tensor_dict``. Other workers do
        not read ``seq_group_metadata_list``; they receive the broadcast dict
        and build a ``SamplingMetadata`` with ``perform_sampling=False``.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. ``lora_mapping`` is ``None`` when LoRA is
            disabled, and ``multi_modal_input`` is ``None`` for decode
            batches or batches without multi-modal data.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests, multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
                multi_modal_input = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            # Whatever remains are the attention-metadata fields.
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward step and, where requested, sample next tokens.

        Args:
            seq_group_metadata_list: Scheduled sequence groups, forwarded to
                ``prepare_input_tensors``. May be ``None`` (presumably on
                non-driver workers, which receive their inputs by broadcast
                inside ``prepare_input_tensors`` — confirm against callers).
            kv_caches: KV cache tensors handed to the model.

        Returns:
            The output of ``self.model.sample``, or ``None`` when the
            prepared sampling metadata has ``perform_sampling`` disabled.
        """
        prepared = self.prepare_input_tensors(seq_group_metadata_list)
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input) = prepared

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Batches flagged for CUDA graphs replay the graph runner that was
        # captured for their batch size; everything else runs eagerly.
        if not attn_metadata.use_cuda_graph:
            model_executable = self.model
        else:
            model_executable = self.graph_runners[input_tokens.shape[0]]

        model_kwargs = dict(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        if self.vision_language_config:
            model_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**model_kwargs)

        # Logits are computed even when this worker does not sample.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        return self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on worst-case dummy prompts to profile memory.

        Builds up to ``max_num_seqs`` dummy prompt sequences whose lengths
        sum to ``max_num_batched_tokens``. A vision-language model gets fewer
        sequences, so that each can hold one image's worth of tokens. When
        LoRA is enabled, dummy adapters are attached. The batch is executed
        with no KV cache (``None`` per layer), followed by a CUDA sync.

        Raises:
            ValueError: If a vision-language model's image feature size
                exceeds ``max_num_batched_tokens``, so not even a single
                image prompt fits in one batch.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Register ``max_loras`` dummy LoRA adapters (rank LORA_WARMUP_RANK)
        # and assign them round-robin to the dummy sequences. The profiled
        # batch then carries as many distinct LoRAs as the config allows,
        # which is presumably the worst case for LoRA memory consumption.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the vLLM
        # block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            image_feature_size = self.vision_language_config.image_feature_size
            if image_feature_size > max_num_batched_tokens:
                # Otherwise max_num_seqs would drop to 0 and an empty dummy
                # batch would be executed.
                raise ValueError(
                    f"The image feature size ({image_feature_size}) exceeds "
                    f"max_num_batched_tokens ({max_num_batched_tokens}); no "
                    "image prompt can fit in a single batch.")
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // image_feature_size)

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (tokens % seqs) sequences get one extra token each.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return

766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
    def _require_lora_manager(self):
        """Return ``self.lora_manager``, raising if LoRA support is disabled.

        Shared guard for the public LoRA wrappers below.

        Raises:
            RuntimeError: If ``self.lora_manager`` is falsy (LoRA not enabled).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager

    def remove_all_loras(self) -> bool:
        """Delegate to the LoRA manager's ``remove_all_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        return self._require_lora_manager().remove_all_loras()

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Delegate to the LoRA manager's ``set_active_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        self._require_lora_manager().set_active_loras(lora_requests,
                                                      lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the LoRA manager's ``add_lora``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        return self._require_lora_manager().add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the LoRA manager's ``remove_lora``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        return self._require_lora_manager().remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the LoRA manager's ``list_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        return self._require_lora_manager().list_loras()

792
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``CUDAGraphRunner`` is captured for each size in
        ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed the padded
        ``max_num_seqs``. Captures run from the largest size down and chain
        through ``self.graph_memory_pool``. Each runner is stored in
        ``self.graph_runners`` keyed by its batch size.

        Args:
            kv_caches: KV cache tensors passed to every captured forward pass.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    num_prompt_tokens=0,
                    num_generation_tokens=batch_size,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        # Lazy %-style args: formatting is deferred to the logging framework.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

Woosuk Kwon's avatar
Woosuk Kwon committed
889
    def __del__(self) -> None:
        """Drop captured CUDA graphs, then the pynccl backend reference.

        Python also calls ``__del__`` on objects whose ``__init__`` raised
        part-way through. ``graph_runners`` may therefore be missing, and
        that case is tolerated instead of raising AttributeError during
        finalization.
        """
        # Delete the CUDA graphs before deleting the pynccl communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
898

899
900
901
902
    @property
    def vocab_size(self) -> int:
        """Vocabulary size, as reported by ``model_config.get_vocab_size()``."""
        config = self.model_config
        return config.get_vocab_size()

903
904
905
906
907
908
909
910
911
912
913
914
915

class CUDAGraphRunner:
    """Captures one model forward pass into a CUDA graph and replays it.

    The tensors given to ``capture`` become the graph's static input
    buffers. ``forward`` copies new inputs into those buffers, replays the
    graph, and returns the ``hidden_states`` tensor recorded during capture.
    That is the same tensor object on every call.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Set by ``capture``; ``None`` until then.
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Warm the model up once, then record it into ``self.graph``.

        Args:
            input_ids: Token ids; kept as a static input buffer.
            positions: Positions; kept as a static input buffer.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata. Its ``slot_mapping``,
                ``context_lens`` and ``block_tables`` are kept as static
                input buffers.
            memory_pool: Passed as ``pool=`` to ``torch.cuda.graph``.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self.graph is None

        def run_model():
            return self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )

        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            run_model()
        torch.cuda.synchronize()

        # Capture the graph. Both context managers fit on a single line, so
        # no multi-line ``with`` (unsupported on Python 3.8) is needed.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool), _maybe_pynccl():
            hidden_states = run_model()
        torch.cuda.synchronize()

        # Keep references to the static tensors the graph reads and writes.
        self.input_buffers = dict(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            slot_mapping=attn_metadata.slot_mapping,
            context_lens=attn_metadata.context_lens,
            block_tables=attn_metadata.block_tables,
        )
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy inputs into the static buffers, replay, return the output.

        ``kv_caches`` and any extra ``**kwargs`` are ignored.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        buffers = self.input_buffers
        buffers["input_ids"].copy_(input_ids, non_blocking=True)
        buffers["positions"].copy_(positions, non_blocking=True)
        for name in ("slot_mapping", "context_lens", "block_tables"):
            buffers[name].copy_(getattr(attn_metadata, name),
                                non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

991

992
@contextlib.contextmanager
def _maybe_pynccl():
    """Use pynccl for all-reduce if it is set up and custom all-reduce is not.

    Enters ``with_pynccl_for_all_reduce()`` when ``pynccl_utils`` is
    initialized and ``custom_all_reduce`` is not. Otherwise this is a no-op
    context.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    ctx = (with_pynccl_for_all_reduce()
           if use_pynccl else contextlib.nullcontext())
    with ctx:
        yield


1002
def _get_graph_batch_size(batch_size: int) -> int:
    """Pad ``batch_size`` up to the nearest CUDA-graph batch size.

    Sizes 1 and 2 map to themselves, 3 and 4 map to 4, and anything larger
    is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    """
    if batch_size > 4:
        # Ceiling division, then scale back up to the alignment.
        return -(-batch_size // _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT
    return batch_size if batch_size <= 2 else 4
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run.

    Returns a ``(SequenceData, fake_image_input)`` pair. Without a vision
    config the prompt is ``seq_len`` zero tokens and the image input is
    ``None``. With one, the prompt starts with ``image_feature_size`` copies
    of ``image_token_id`` and is zero-padded up to ``seq_len``. The image is
    then an all-zero float16 tensor of ``image_input_shape``.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    padding = [0] * (seq_len - num_image_tokens)
    prompt_tokens = image_tokens + padding
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(prompt_tokens), fake_image_input