"vllm/vscode:/vscode.git/clone" did not exist on "5d5b4c5fe524c3b62453bba7ad4434a27c81317a"
model_runner.py 44.3 KB
Newer Older
1
import contextlib
2
import time
3
from typing import Dict, List, Optional, Set, Tuple
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

9
from vllm.attention import AttentionMetadata, get_attn_backend
10
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
11
                         SchedulerConfig, VisionLanguageConfig)
12
13
14
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
                                                   pynccl_utils)
15
from vllm.logger import init_logger
16
17
18
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
19
from vllm.model_executor import SamplingMetadata
20
from vllm.model_executor.model_loader import get_model
21
from vllm.sampling_params import SamplingParams, SamplingType
22
23
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
24
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
25
26
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
27
28
29
30

# Module-level logger, created through vLLM's init_logger with this module's
# name.
logger = init_logger(__name__)

# Slot-mapping value used for padded or masked token positions.
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Step between consecutive captured batch sizes once past the small sizes.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4]
_BATCH_SIZES_TO_CAPTURE.extend(
    range(_BATCH_SIZE_ALIGNMENT, 33 * _BATCH_SIZE_ALIGNMENT,
          _BATCH_SIZE_ALIGNMENT))
38
39
40
41
42
43
44
45
46


class ModelRunner:

    def __init__(
        self,
        model_config: Optional[ModelConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: Optional[DeviceConfig],
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the configuration and initialise the runner's lazy state.

        No model is loaded here: `self.model`, `self.block_size`,
        `self.lora_manager` and `self.graph_block_tables` start as None and
        are filled in later (see `load_model` and `set_block_size`).

        Args:
            model_config: Model configuration. May be None (this happens in
                tests/samplers/test_sampler.py); in that case the sliding
                window is None, `max_context_len_to_capture` is 0 and the
                attention backend is requested with a dtype of None.
            parallel_config: Stored as-is; passed to `get_model`.
            scheduler_config: Stored as-is; passed to `get_model` and read
                when preparing inputs.
            device_config: Device configuration. When None, a default
                `DeviceConfig()` is created.
            lora_config: LoRA configuration, or None to disable LoRA.
            kv_cache_dtype: KV cache data type string forwarded to the
                attention metadata.
            is_driver_worker: Whether this runner prepares and broadcasts
                the inputs (True) or receives them (False).
            vision_language_config: Optional vision-language configuration.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

91
    def load_model(self) -> None:
        """Load the model and perform LoRA / FP8 KV-cache post-load setup.

        Steps:
          1. Build the model with `get_model` inside a `CudaMemoryProfiler`
             context and record the profiler's `consumed_memory` in
             `self.model_memory_usage`.
          2. If a LoRA config is set, check that the model exposes the
             attributes LoRA needs, create an `LRUCacheWorkerLoRAManager`
             and replace `self.model` with the manager-wrapped model.
          3. If the KV cache dtype is "fp8" on ROCm (`is_hip()`), load the
             KV cache scaling factors from
             `model_config.quantization_param_path` when one is given.

        Raises:
            AssertionError: If LoRA is enabled but the model lacks
                `supported_lora_modules`, `embedding_modules` or
                `embedding_padding_modules`.
            RuntimeError: If FP8 scaling factors are provided but the model
                has no callable `load_kv_cache_scales`.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                self.model_config,
                self.device_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)

        self.model_memory_usage = m.consumed_memory
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently scaled KV cache is only enabled on ROCm
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                else:
                    raise RuntimeError("Using FP8 KV cache and scaling "
                                       "factors provided but model "
                                       f"{self.model.__class__} does not "
                                       "support loading scaling factors.")
            else:
                # Logger.warn is a deprecated alias; use Logger.warning.
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
        elif self.model_config.quantization_param_path is not None:
            logger.warning("KV cache scaling factors provided, "
                           "but the KV cache data type is not FP8. "
                           "KV cache scaling factors will not be used.")

141
142
143
    def set_block_size(self, block_size: int) -> None:
        """Record the block size and allocate the cached graph block table.

        The cached table is a zero-filled int32 numpy array with one row per
        sequence of the largest captured batch size and
        `get_max_block_per_batch()` columns.

        Args:
            block_size: Number of token slots per block.
        """
        self.block_size = block_size

        # get_max_block_per_batch() reads self.block_size, so it must be
        # called after the assignment above.
        table_shape = (max(_BATCH_SIZES_TO_CAPTURE),
                       self.get_max_block_per_batch())
        self.graph_block_tables = np.zeros(table_shape, dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_context_len_to_capture + block_size - 1) // block_size
151

152
153
154
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest],
               Optional[torch.Tensor]]:
        """Build the model inputs for a batch of prompt (prefill) requests.

        Every sequence group must be a prompt holding exactly one sequence.
        The tokens of all prompts are flattened into one token list.

        Args:
            seq_group_metadata_list: Non-empty list of prompt sequence
                groups.

        Returns:
            A 9-tuple `(input_tokens, input_positions, attn_metadata,
            prompt_lens, subquery_lens, lora_index_mapping,
            lora_prompt_mapping, lora_requests, multi_modal_input)`.
            `input_tokens` and `input_positions` are long tensors on
            `self.device`; `multi_modal_input` is None when no sequence
            group carries multi-modal data.

        Raises:
            RuntimeError: If chunked prefill is enabled and a sequence group
                has computed (prefix-cached) block numbers.
            AssertionError: If an input assumption documented inline does
                not hold.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            computed_block_nums = seq_group_metadata.computed_block_nums
            if (self.scheduler_config is not None
                    and self.scheduler_config.chunked_prefill_enabled
                    and computed_block_nums is not None):
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            token_chunk_size = seq_group_metadata.token_chunk_size
            seq_data = seq_group_metadata.seq_data[seq_id]
            computed_len = seq_data.get_num_computed_tokens()
            # We should use get_len here because in case of preemption
            # it contains output tokens.
            prefill_end = min(seq_data.get_len(),
                              computed_len + token_chunk_size)
            # TODO(sang): Rename it after chunked prefill is introduced.
            prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
            prompt_len = len(prompt_tokens)
            # Right now, the prefill_end is always same as the length of
            # sequence. However, once chunked prefill is introduced, this
            # assumption can be changed.
            assert prefill_end == seq_data.get_len()
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            if computed_block_nums is not None and len(
                    computed_block_nums) > 0 and self.sliding_window is None:
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
                # Right now, prefill start is always 0. However, this
                # assumption can be changed once chunked prefill is introduced.
                assert computed_len == 0

            # actual prompt lens
            context_lens.append(computed_len)
            subquery_lens.append(prompt_len - computed_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(computed_len, prefill_end))
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * (prompt_len - computed_len)
            # One entry per query token when prompt logprobs are requested,
            # otherwise a single entry for the prompt.
            lora_prompt_mapping.extend(
                [lora_id] *
                (prompt_len - computed_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prefill_end):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Element 0 stays 0; elements 1.. receive the running totals, so
        # entry i is the start offset of sequence i in the flattened batch.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests, multi_modal_input)
353
354
355
356

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        """Build the model inputs for a batch of decode requests.

        Each sequence contributes exactly one input token (its last token).
        When a captured graph can be used, the batch is padded up to the
        batch size returned by `_get_graph_batch_size` and the block tables
        are copied into the cached `self.graph_block_tables` array.

        Args:
            seq_group_metadata_list: Non-empty list of decode sequence
                groups, each with `token_chunk_size == 1`.

        Returns:
            A 6-tuple `(input_tokens, input_positions, attn_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)`.
            `input_tokens` and `input_positions` are long tensors on
            `self.device`.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # With a sliding window, the context is capped at the
                # window size.
                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks spanned by the window.
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == len(input_tokens).
        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            # Pad every per-token list up to the captured batch size.
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                context_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be
            # padded.
            assert context_lens.shape[0] == input_tokens.shape[0]
            assert context_lens.shape[0] == input_positions.shape[0]
            assert context_lens.shape[0] == slot_mapping.shape[0]

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            prompt_lens_tensor=None,
            num_prompt_tokens=0,
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)
484
485
486
487
488

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the `SamplingMetadata` for the current batch.

        Walks the sequence groups in order and records:
          * `selected_token_indices`: positions in the flattened token batch
            that are selected. For a prompt this is its last query token,
            plus all earlier query tokens when prompt logprobs are
            requested. For a decode group it is one position per sequence.
          * `categorized_sample_indices`: for each `SamplingType`, pairs of
            running indices for the positions that are sampled.
          * `generators`: the per-group `torch.Generator` of every group
            whose sampling params carry a seed. A new seeded generator is
            created and stored on the group's state when the group is a
            prompt.

        Args:
            seq_group_metadata_list: Sequence groups in the batch.
            prompt_lens: Prompt lengths, passed through to the metadata.
            subquery_lens: Per-group query lengths; must not be None when
                the batch contains prompts.

        Returns:
            The populated `SamplingMetadata`.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append([
                        categorized_sample_indices_start_idx,
                        categorized_sampled_token_indices_start_idx
                    ])
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += subquery_len

                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)

        # Convert each sampling type's list of index pairs to a tensor.
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(index_pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for t, index_pairs in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

583
584
585
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping],
               Optional[torch.Tensor]]:
        """Prepare the model inputs, or receive them from the driver.

        On the driver worker the inputs are built from
        `seq_group_metadata_list` (via `_prepare_prompt` or
        `_prepare_decode`, then `_prepare_sample`) and broadcast from
        source 0 with `broadcast_tensor_dict`. On any other worker the
        broadcast dictionary is received and unpacked, and a
        `SamplingMetadata` with `perform_sampling=False` is created.

        Args:
            seq_group_metadata_list: Sequence groups to run. Only read on
                the driver worker, where its first element decides whether
                the whole batch is treated as prompts or decodes.

        Returns:
            A 7-tuple `(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)`. On the driver, `lora_mapping` is None when
            no LoRA config is set, and `multi_modal_input` is None for
            decode batches.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests, multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
                multi_modal_input = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            # Pop the non-attention entries; what remains is passed to
            # make_metadata as keyword arguments.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)
652

653
654
655
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, where sampling is enabled, sample.

        Args:
            seq_group_metadata_list: Sequence groups to run; forwarded
                unchanged to `prepare_input_tensors`.
            kv_caches: KV cache tensors handed to the model call.

        Returns:
            The result of `self.model.sample(...)`, or None when the
            prepared sampling metadata has `perform_sampling` unset.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model. Captured graph runners are keyed by the
        # batch size of the input tensor.
        if attn_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        # The image input is only passed when a vision-language config is set.
        if self.vision_language_config:
            execute_model_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        return self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run the model once on dummy inputs sized to the scheduler limits.

        Builds `max_num_seqs` prompt sequence groups whose lengths sum to
        `max_num_batched_tokens`, attaches dummy LoRA requests when LoRA is
        configured, executes the model with `None` KV caches and then
        synchronizes CUDA.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            # LoRA ids start at 1.
            for lora_id in range(1, self.lora_config.max_loras + 1):
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            # Assign the dummy requests to sequences round-robin.
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            max_num_seqs = min(
                max_num_seqs,
                max_num_batched_tokens //
                self.vision_language_config.image_feature_size)

        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups take one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
    def remove_all_loras(self) -> bool:
        """Forward to the LoRA manager's `remove_all_loras`.

        Raises:
            RuntimeError: If `self.lora_manager` is falsy.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the requests and mapping to the LoRA manager.

        Raises:
            RuntimeError: If `self.lora_manager` is falsy.
        """
        manager = self.lora_manager
        if manager:
            manager.set_active_loras(lora_requests, lora_mapping)
            return
        raise RuntimeError("LoRA is not enabled.")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward `lora_request` to the LoRA manager's `add_lora`.

        Raises:
            RuntimeError: If `self.lora_manager` is falsy.
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Forward `lora_id` to the LoRA manager's `remove_lora`.

        Raises:
            RuntimeError: If `self.lora_manager` is falsy.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return whatever the LoRA manager's `list_loras` reports.

        Raises:
            RuntimeError: If `self.lora_manager` is falsy.
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

789
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        Args:
            kv_caches: KV cache tensors passed to every capture run.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.pynccl_backend = pynccl_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Only capture batch sizes up to the padded max_num_seqs.
        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or pynccl if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    num_prompt_tokens=0,
                    num_generation_tokens=batch_size,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_prompt_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Hand the pool of this graph to the next capture.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

Woosuk Kwon's avatar
Woosuk Kwon committed
886
    def __del__(self) -> None:
887
        # Delete the CUDA graphs before deleting the pynccl communicator.
Woosuk Kwon's avatar
Woosuk Kwon committed
888
889
890
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
891
892
        # TODO(youkaichao): when we get enough user feedback that pynccl is
        # more stable than cupy, we can remove this, e.g. in v0.4.1.
Woosuk Kwon's avatar
Woosuk Kwon committed
893
        self.graph_runners.clear()
894
        self.pynccl_backend = None
Woosuk Kwon's avatar
Woosuk Kwon committed
895

896
897
898
899
    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

900
901
902
903
904
905
906
907
908
909
910
911
912

class CUDAGraphRunner:
    """Captures one CUDA graph of a model call and replays it on demand.

    `capture` records the graph on a fixed set of tensors; `forward` copies
    fresh inputs into those tensors, replays the graph and returns the
    captured output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Set by `capture`; None until a graph has been recorded.
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        """Record a CUDA graph of `self.model` on the given tensors.

        The tensors passed here are kept as the input buffers that
        `forward` later copies into. May only be called once per runner.

        Args:
            input_ids: Input tensor used for the warm-up run and the capture.
            positions: Positions tensor used the same way.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata; its `slot_mapping`,
                `context_lens` and `block_tables` become input buffers.
            memory_pool: Passed as `pool` to `torch.cuda.graph`.
            **kwargs: Extra keyword arguments forwarded to the model.
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_pynccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_pynccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                    **kwargs,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": attn_metadata.context_lens,
            "block_tables": attn_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs.

        Args:
            input_ids: Copied into the captured `input_ids` buffer.
            positions: Copied into the captured `positions` buffer.
            kv_caches: Unused; the captured KV cache tensors are reused.
            attn_metadata: Its `slot_mapping`, `context_lens` and
                `block_tables` are copied into the captured buffers.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The captured `hidden_states` output tensor. The same tensor
            object is returned on every call.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
                                                 non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        """Alias for `forward`."""
        return self.forward(*args, **kwargs)

988

989
@contextlib.contextmanager
def _maybe_pynccl():
    """Enter `with_pynccl_for_all_reduce()` only when pynccl should be used.

    That is the case when pynccl is initialized and the custom all-reduce
    is not; otherwise the body runs with no extra context.
    """
    use_pynccl = (pynccl_utils.is_initialized()
                  and not custom_all_reduce.is_initialized())
    if use_pynccl:
        with with_pynccl_for_all_reduce():
            yield
    else:
        yield


999
def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    Args:
        batch_size: The actual batch size.

    Returns:
        `batch_size` itself when it is at most 2, 4 when it is 3 or 4,
        otherwise `batch_size` rounded up to the next multiple of
        _BATCH_SIZE_ALIGNMENT.
    """
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        # Ceiling division, then scale back up to a multiple.
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build the dummy sequence data and image input for a profile run.

    Without a vision-language config the prompt is `seq_len` zero tokens and
    there is no image input. With one, the prompt starts with
    `image_feature_size` copies of `image_token_id`, is padded with zero
    tokens, and is paired with an all-zero float16 image tensor of shape
    `image_input_shape`.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    feature_size = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * feature_size
    padding_tokens = [0] * (seq_len - feature_size)
    zero_image = torch.zeros(vision_language_config.image_input_shape,
                             dtype=torch.float16)
    fake_image_input = MultiModalData(type=MultiModalData.Type.IMAGE,
                                      data=zero_image)
    return SequenceData(image_tokens + padding_tokens), fake_image_input