import time
from typing import Dict, List, Optional, Tuple, Set, Union

import numpy as np
import torch
import torch.nn as nn

from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)
from vllm.model_executor.parallel_utils import custom_all_reduce
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
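        # For example, with max_context_len_to_capture = 8192 and a block
        # size of 16, the cached table would have shape (256, 512) in int32
        # (about 0.5 MiB), since 256 is the largest captured batch size.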
        self.graph_block_tables = None  # Set after initial profiling.
        # cache in_wsl result
        self.in_wsl = in_wsl()
        self.kv_cache_dtype = kv_cache_dtype

    def load_model(self) -> None:
        self.model = get_model(self.model_config, self.device_config,
                               self.lora_config)

        vocab_size = self.model.config.vocab_size

        if self.lora_config:
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens +
                self.scheduler_config.max_paddings, vocab_size,
                self.lora_config, self.device)
            self.model = self.lora_manager.create_lora_manager(self.model)

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

        max_num_blocks = (self.max_context_len_to_capture + block_size -
                          1) // block_size
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               List[List[int]], List[int], Set[LoRARequest]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        lora_index_mapping: List[List[int]] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)
            prefix_len = 0
            prefix = seq_group_metadata.prefix
            if prefix is not None and prefix.computed:
                prefix_len = prefix.get_length()
                prompt_tokens = prompt_tokens[prefix_len:]
                prefix_block_tables.append(prefix.get_block_numbers())
            else:
                prefix_block_tables.append([])
            # context_lens holds the number of tokens already computed in the
            # cached prefix; subquery_lens holds the number of new prompt
            # tokens to compute for this sequence.
            context_lens.append(prefix_len)
            subquery_lens.append(prompt_len - prefix_len)

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(
                list(range(prefix_len, prefix_len + len(prompt_tokens))))

            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
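            # When prompt logprobs are requested, every prompt position needs
            # a LoRA mapping entry (logits are produced for each token);
            # otherwise only the final position is sampled, so one entry
            # suffices.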
            lora_prompt_mapping.extend(
                [lora_id] *
                (prompt_len - prefix_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert prefix_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prefix_len, prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(subquery_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_prompt_len,
                                                pad=0,
                                                dtype=torch.long,
                                                device=self.device)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device=self.device)
        lora_index_mapping = [
            _pad_to_max(mapping, max_prompt_len, pad=0)
            for mapping in lora_index_mapping
        ]
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)
        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = _make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )
        start_loc_tensor = torch.arange(0,
                                        len(prompt_lens) * max_prompt_len,
                                        max_prompt_len,
                                        dtype=torch.long,
                                        device=self.device)
        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)

        input_metadata = InputMetadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens_tensor,
            max_seq_len=max_prompt_len,
            start_loc=start_loc_tensor,
            max_context_len=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[List[int]],
               List[int], Set[LoRARequest]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[List[int]] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])
                lora_index_mapping.append([lora_id])
                lora_prompt_mapping.append(lora_id)

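                # With sliding-window attention, only the most recent blocks
                # (enough to cover the window) are needed for this decode step.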
                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
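        # A captured CUDA graph can be replayed only when eager mode is off
        # and both the batch size and the context length fit within the
        # sizes that were captured.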
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            # Pad the input tokens, positions, and slot mapping to match the
            # batch size of the captured graph.
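            # For example, a decode batch of 5 sequences is padded with 3
            # empty sequences so it can be replayed with the graph captured
            # for batch size 8.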
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append([])
                input_positions.append([])
                slot_mapping.append([])
                context_lens.append(1)
                block_tables.append([])
            batch_size = graph_batch_size

        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_len=1,
                                                pad=0,
                                                dtype=torch.long,
                                                device=self.device)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_len=1,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = _make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        lora_index_mapping = [
            _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
        ]

        input_metadata = InputMetadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            max_seq_len=None,
            start_loc=None,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_subquery_len = max(subquery_lens) if subquery_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need to be sampled;
                    # skip them.
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += max_subquery_len
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(categorized_sample_indices_start_idx,
                              categorized_sample_indices_start_idx + num_seqs))
                categorized_sample_indices_start_idx += num_seqs

        selected_token_indices = _async_h2d(selected_token_indices,
                                            dtype=torch.long,
                                            target_device=self.device,
                                            pin_memory=not self.in_wsl)
        categorized_sample_indices = {
            t: _async_h2d(seq_ids,
                          dtype=torch.int,
                          target_device=self.device,
                          pin_memory=not self.in_wsl)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
               Set[int], LoRAMapping]:
        if self.is_driver_worker:
            # NOTE: We assume that the sequence groups in the batch are either
            # all prompts or all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, input_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                flat_lora_index_mapping = [
                    item for sublist in lora_index_mapping for item in sublist
                ]
                lora_mapping = LoRAMapping(
                    flat_lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "is_prompt": input_metadata.is_prompt,
                "slot_mapping": input_metadata.slot_mapping,
                "prompt_lens": input_metadata.prompt_lens,
                "max_seq_len": input_metadata.max_seq_len,
                "start_loc": input_metadata.start_loc,
                "max_context_len": input_metadata.max_context_len,
                "context_lens": input_metadata.context_lens,
                "block_tables": input_metadata.block_tables,
                "use_cuda_graph": input_metadata.use_cuda_graph,
                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
            }
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict["input_tokens"]
            input_positions = metadata_dict["input_positions"]
            lora_mapping = metadata_dict["lora_mapping"]
            lora_requests = metadata_dict["lora_requests"]
            input_metadata = InputMetadata(
                is_prompt=metadata_dict["is_prompt"],
                slot_mapping=metadata_dict["slot_mapping"],
                prompt_lens=metadata_dict["prompt_lens"],
                max_seq_len=metadata_dict["max_seq_len"],
                start_loc=metadata_dict["start_loc"],
                max_context_len=metadata_dict["max_context_len"],
                context_lens=metadata_dict["context_lens"],
                block_tables=metadata_dict["block_tables"],
                use_cuda_graph=metadata_dict["use_cuda_graph"],
                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
            )
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=metadata_dict["selected_token_indices"],
                categorized_sample_indices=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, input_metadata,
                sampling_metadata, lora_requests, lora_mapping)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, input_metadata, sampling_metadata,
         lora_requests,
         lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model.
        if input_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests that will
        # have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
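        # For example, with max_num_batched_tokens = 2048 and
        # max_num_seqs = 256, each dummy sequence is 8 tokens long.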
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return

    def remove_all_loras(self) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_all_loras()

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.cupy_nccl_backend = get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, 1,
                                      dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
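        # Batch sizes above the (rounded-up) scheduler limit can never occur
        # at decode time, so graphs are captured only up to graph_batch_size.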
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or CuPy NCCL if it is disabled or not supported.
        with custom_all_reduce.capture():
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy input_metadata.
                input_metadata = InputMetadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    max_seq_len=None,
                    start_loc=None,
                    max_context_len=self.max_context_len_to_capture,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    input_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

    def __del__(self) -> None:
        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        self.graph_runners.clear()
        self.cupy_nccl_backend = None


class CUDAGraphRunner:

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with with_cupy_nccl_for_all_reduce():
            self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with with_cupy_nccl_for_all_reduce():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    input_metadata,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
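    # Right-pad every row to max_len with `pad`, then build one 2-D tensor
    # on the requested device.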
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device=device)


def _get_graph_batch_size(batch_size: int) -> int:
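    # Round up to the nearest captured batch size: 1 and 2 map to themselves,
    # 3-4 map to 4, and larger sizes round up to the next multiple of 8
    # (e.g., 5 -> 8, 17 -> 24).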
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        return (batch_size + 7) // 8 * 8


def _async_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
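    # Stage the data in host memory (pinned when possible so the copy can be
    # truly asynchronous), then issue a non-blocking copy to the target device.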
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)