model_runner.py 38 KB
Newer Older
1
import contextlib
2
import time
3
from typing import Dict, List, Optional, Tuple, Set, Union
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                         SchedulerConfig)
11
12
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
13
from vllm.model_executor.parallel_utils import cupy_utils
14
from vllm.model_executor.parallel_utils.communication_op import (
15
    broadcast_tensor_dict)
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)
18
from vllm.model_executor.parallel_utils import custom_all_reduce
19
20
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
21
22
23
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
24
from vllm.utils import in_wsl
25
26
27

# Module-level logger for this file (used e.g. for the CUDA graph capture
# notice in ModelRunner.capture_model).
logger = init_logger(__name__)

28
# Per-layer cache entry: a pair of tensors (presumably key cache and value
# cache — confirm against the cache engine).
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Slot index used for tokens that must not be written to the KV cache:
# padding, tokens outside the sliding window, and profiling runs.
_PAD_SLOT_ID = -1
# LoRA rank given to the dummy adapters registered in profile_run.
LORA_WARMUP_RANK = 8

# CUDA graphs are captured for batch sizes 1, 2, 4 and then every multiple of
# 8 from 8 through 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, *range(8, 257, 8)]
34
35
36
37
38
39
40
41
42


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        """Store the engine configs and set up lazily-populated state.

        Args:
            model_config: Model configuration. May be None (some sampler
                tests build the runner without one).
            parallel_config: Parallelism configuration; forwarded to
                ``get_model`` and used to count layers in ``profile_run``.
            scheduler_config: Scheduler limits used to size the LoRA manager
                and the profiling batch.
            device_config: Target device; ``DeviceConfig()`` is used when
                None is passed.
            lora_config: LoRA settings, or None when LoRA is disabled.
            kv_cache_dtype: KV-cache dtype string forwarded into
                ``InputMetadata``.
            is_driver_worker: True on the worker that builds the inputs and
                broadcasts them to the others.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # FIXME(woosuk): model_config can be None in
        # tests/samplers/test_sampler.py. This is a hack to make those tests
        # work. Refactor this.
        if model_config is None:
            self.sliding_window = None
        else:
            self.sliding_window = model_config.get_sliding_window()

        if device_config is None:
            device_config = DeviceConfig()
        self.device_config = device_config
        self.device = device_config.device

        # Filled in by load_model() / set_block_size() (block size is known
        # only after initial profiling).
        self.model = None
        self.block_size = None
        self.lora_manager = None

        # CUDA graph state; populated during graph capture.
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None

        if self.model_config is None:
            self.max_context_len_to_capture = 0
        else:
            self.max_context_len_to_capture = (
                self.model_config.max_context_len_to_capture)
        # With CUDA graphs, input block tables must be padded to
        # max_context_len_to_capture. Building that table in Python every
        # step is expensive, so a numpy buffer of shape
        # (max captured batch size, max context len to capture / block size)
        # is allocated once in set_block_size() and only the live rows are
        # overwritten per step.
        self.graph_block_tables = None

        # in_wsl() is evaluated once and reused when deciding pin_memory.
        self.in_wsl = in_wsl()
        self.kv_cache_dtype = kv_cache_dtype

        # The Neuron backend must run eagerly; never capture graphs there.
        if self.device_config.is_neuron:
            self.model_config.enforce_eager = True

87
    def load_model(self) -> None:
        """Load the model and, when LoRA is enabled, wrap it for LoRA.

        Raises:
            AssertionError: LoRA is enabled but the loaded model does not
                expose non-empty ``supported_lora_modules``, or lacks
                ``embedding_modules`` / ``embedding_padding_modules``.
        """
        model = get_model(self.model_config,
                          self.device_config,
                          lora_config=self.lora_config,
                          parallel_config=self.parallel_config,
                          scheduler_config=self.scheduler_config)
        self.model = model
        vocab_size = model.config.vocab_size

        if not self.lora_config:
            return

        assert hasattr(model, "supported_lora_modules") and \
            model.supported_lora_modules, "Model does not support LoRA"
        assert hasattr(model, "embedding_modules"), \
            "Model does not have embedding_modules"
        assert hasattr(model, "embedding_padding_modules"), \
            "Model does not have embedding_padding_modules"

        sched = self.scheduler_config
        # The token capacity handed to the LoRA manager is the batched-token
        # limit plus the allowed padding.
        self.lora_manager = LRUCacheWorkerLoRAManager(
            sched.max_num_seqs,
            sched.max_num_batched_tokens + sched.max_paddings,
            vocab_size,
            self.lora_config,
            self.device,
            model.embedding_modules,
            model.embedding_padding_modules,
        )
        self.model = self.lora_manager.create_lora_manager(model)
112
113
114
115

    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the graph block table.

        The cached int32 table has one row per slot of the largest captured
        batch size and enough columns to cover ``max_context_len_to_capture``
        tokens, rounded up to whole blocks.
        """
        self.block_size = block_size

        # Ceiling division: number of blocks spanning the capturable context.
        blocks_per_seq = (self.max_context_len_to_capture + block_size -
                          1) // block_size
        table_shape = (max(_BATCH_SIZES_TO_CAPTURE), blocks_per_seq)
        self.graph_block_tables = np.zeros(table_shape, dtype=np.int32)

121
122
123
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               List[List[int]], List[int], Set[LoRARequest]]:
        """Build model inputs for a batch of prompt (prefill) sequence groups.

        Every group must be a prompt holding exactly one sequence. When the
        group's prefix has already been computed, only the tokens after the
        prefix are fed to the model, and the prefix's block numbers are
        passed through ``block_tables``. Token, position and slot lists are
        right-padded to ``max(subquery_lens)``.

        Args:
            seq_group_metadata_list: Non-empty list of prompt groups.

        Returns:
            ``(input_tokens, input_positions, input_metadata, prompt_lens,
            subquery_lens, lora_index_mapping, lora_prompt_mapping,
            lora_requests)``. ``prompt_lens`` are full prompt lengths;
            ``subquery_lens`` are the token counts actually fed (prompt length
            minus computed prefix); ``lora_index_mapping`` holds one padded
            per-token LoRA id list per sequence; ``lora_prompt_mapping`` holds
            one LoRA id per position that ``_prepare_sample`` selects.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        lora_index_mapping: List[List[int]] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            # A computed prefix is already cached: feed only the remaining
            # tokens and pass the prefix's block numbers separately.
            prefix_len = 0
            prefix = seq_group_metadata.prefix
            if prefix is not None and prefix.computed:
                prefix_len = prefix.get_length()
                prompt_tokens = prompt_tokens[prefix_len:]
                prefix_block_tables.append(prefix.get_block_numbers())
            else:
                prefix_block_tables.append([])
            # context_lens: tokens already cached; subquery_lens: tokens fed.
            context_lens.append(prefix_len)
            subquery_lens.append(prompt_len - prefix_len)

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(
                list(range(prefix_len, prefix_len + len(prompt_tokens))))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
            # One LoRA id per position that _prepare_sample selects: every fed
            # position when prompt logprobs are requested, otherwise just the
            # last one. Use `is not None` (as _prepare_sample does) so that
            # prompt_logprobs=0 does not leave this mapping shorter than the
            # number of selected rows.
            wants_prompt_logprobs = (
                seq_group_metadata.sampling_params.prompt_logprobs is not None)
            lora_prompt_mapping.extend(
                [lora_id] *
                (prompt_len - prefix_len if wants_prompt_logprobs else 1))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert prefix_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prefix_len, prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(subquery_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_prompt_len,
                                                pad=0,
                                                dtype=torch.long,
                                                device=self.device)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device=self.device)
        lora_index_mapping = [
            _pad_to_max(mapping, max_prompt_len, pad=0)
            for mapping in lora_index_mapping
        ]
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)
        # Prepare prefix block tables (empty rows for groups with no prefix).
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = _make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )
        # Each padded prompt occupies max_prompt_len rows, so prompt i starts
        # at row i * max_prompt_len.
        start_loc_tensor = torch.arange(0,
                                        len(prompt_lens) * max_prompt_len,
                                        max_prompt_len,
                                        dtype=torch.long,
                                        device=self.device)
        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)

        input_metadata = InputMetadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens_tensor,
            max_seq_len=max_prompt_len,
            start_loc=start_loc_tensor,
            max_context_len=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests)
263
264
265
266

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               Set[LoRARequest]]:
        """Build model inputs for one decode step over a batch of groups.

        Every sequence contributes exactly one token (its last one) at
        position ``seq_len - 1``. When a captured CUDA graph can be used, the
        batch is padded with dummy rows (no tokens, context length 1) up to
        the nearest captured batch size, and block tables are written into
        the preallocated ``self.graph_block_tables`` buffer.

        Args:
            seq_group_metadata_list: Non-empty list of decode groups.

        Returns:
            ``(input_tokens, input_positions, input_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)``.
        """
        assert len(seq_group_metadata_list) > 0
        tokens: List[List[int]] = []
        positions: List[List[int]] = []
        slots: List[List[int]] = []
        ctx_lens: List[int] = []
        tables: List[List[int]] = []
        lora_index_mapping: List[List[int]] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        block_size = self.block_size
        sliding_window = self.sliding_window

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            group_seq_ids = list(group.seq_data.keys())

            lora_id = group.lora_int_id
            if lora_id > 0:
                lora_requests.add(group.lora_request)

            for seq_id in group_seq_ids:
                seq_data = group.seq_data[seq_id]
                tokens.append([seq_data.get_last_token_id()])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                positions.append([position])

                # With a sliding window, attention never looks further back
                # than the window.
                if sliding_window is None:
                    ctx_lens.append(seq_len)
                else:
                    ctx_lens.append(min(seq_len, sliding_window))

                table = group.block_tables[seq_id]
                block_number = table[position // block_size]
                slots.append(
                    [block_number * block_size + position % block_size])

                lora_index_mapping.append([lora_id])
                lora_prompt_mapping.append(lora_id)

                if sliding_window is not None:
                    # Keep only the trailing blocks covered by the window.
                    table = table[-(sliding_window // block_size):]
                tables.append(table)

        batch_size = len(tokens)
        max_context_len = max(ctx_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            # Pad with dummy rows so the batch matches a captured graph size.
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            num_pad = graph_batch_size - batch_size
            tokens.extend([] for _ in range(num_pad))
            positions.extend([] for _ in range(num_pad))
            slots.extend([] for _ in range(num_pad))
            ctx_lens.extend([1] * num_pad)
            tables.extend([] for _ in range(num_pad))
            batch_size = graph_batch_size

        input_tokens = _make_tensor_with_pad(tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
        input_positions = _make_tensor_with_pad(positions,
                                                max_len=1,
                                                pad=0,
                                                dtype=torch.long,
                                                device=self.device)
        slot_mapping = _make_tensor_with_pad(slots,
                                             max_len=1,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device=self.device)
        context_lens = torch.tensor(ctx_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if use_captured_graph:
            # graph_block_tables has shape
            # [max batch size, max context len // block size]; only the
            # leading columns of non-empty rows are overwritten here.
            input_block_tables = self.graph_block_tables[:batch_size]
            for row, table in enumerate(tables):
                if table:
                    input_block_tables[row, :len(table)] = table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            longest_table = max(len(table) for table in tables)
            block_tables = _make_tensor_with_pad(
                tables,
                max_len=longest_table,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        lora_index_mapping = [
            _pad_to_max(per_seq, 1, pad=0) for per_seq in lora_index_mapping
        ]

        input_metadata = InputMetadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            max_seq_len=None,
            start_loc=None,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)
390
391
392
393
394

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the SamplingMetadata for the current batch.

        Args:
            seq_group_metadata_list: Groups in the same order used to build
                the model inputs.
            prompt_lens: Full prompt length per group (empty for decode).
            subquery_lens: Tokens fed to the model per prompt group; None for
                decode batches.

        Returns:
            SamplingMetadata whose ``selected_token_indices`` are rows of the
            padded token stream (each prompt occupies ``max(subquery_lens)``
            rows, each decode sequence one row) and whose
            ``categorized_sample_indices`` index into those selected rows,
            grouped by sampling type.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0
        pin_memory = not self.in_wsl and not self.device_config.is_neuron

        # Prompts are right-padded to the longest subquery.
        max_subquery_len = max(subquery_lens) if subquery_lens else 1

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # The first subquery_len - 1 selected rows only feed
                    # prompt logprobs; skip them for sampling.
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += max_subquery_len

                if sampling_params.seed is not None:
                    # Create the per-request generator on the runner's device
                    # instead of hard-coding "cuda", so seeded sampling also
                    # works on non-CUDA backends.
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(categorized_sample_indices_start_idx,
                              categorized_sample_indices_start_idx + num_seqs))
                categorized_sample_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = _async_h2d(selected_token_indices,
                                            dtype=torch.long,
                                            target_device=self.device,
                                            pin_memory=pin_memory)
        categorized_sample_indices = {
            t: _async_h2d(indices,
                          dtype=torch.int,
                          target_device=self.device,
                          pin_memory=pin_memory)
            for t, indices in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

477
478
479
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping]]:
        """Build (driver) or receive (other workers) the inputs for one step.

        The driver worker builds the inputs from ``seq_group_metadata_list``
        and broadcasts them with ``broadcast_tensor_dict``. Every other
        worker ignores its argument and rebuilds the same inputs from the
        broadcast; its SamplingMetadata is created with
        ``perform_sampling=False``.

        Returns:
            ``(input_tokens, input_positions, input_metadata,
            sampling_metadata, lora_requests, lora_mapping)``;
            ``lora_mapping`` is None when LoRA is disabled.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, input_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                # LoRAMapping takes one flat per-token index list.
                flat_lora_index_mapping = [
                    item for sublist in lora_index_mapping for item in sublist
                ]
                lora_mapping = LoRAMapping(
                    flat_lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "is_prompt": input_metadata.is_prompt,
                "slot_mapping": input_metadata.slot_mapping,
                "prompt_lens": input_metadata.prompt_lens,
                "max_seq_len": input_metadata.max_seq_len,
                "start_loc": input_metadata.start_loc,
                "max_context_len": input_metadata.max_context_len,
                "context_lens": input_metadata.context_lens,
                "block_tables": input_metadata.block_tables,
                "use_cuda_graph": input_metadata.use_cuda_graph,
                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
            }
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict["input_tokens"]
            input_positions = metadata_dict["input_positions"]
            lora_mapping = metadata_dict["lora_mapping"]
            lora_requests = metadata_dict["lora_requests"]
            input_metadata = InputMetadata(
                is_prompt=metadata_dict["is_prompt"],
                slot_mapping=metadata_dict["slot_mapping"],
                prompt_lens=metadata_dict["prompt_lens"],
                max_seq_len=metadata_dict["max_seq_len"],
                start_loc=metadata_dict["start_loc"],
                max_context_len=metadata_dict["max_context_len"],
                context_lens=metadata_dict["context_lens"],
                block_tables=metadata_dict["block_tables"],
                use_cuda_graph=metadata_dict["use_cuda_graph"],
                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
            )
            # Non-driver workers only need the selected rows; sampling itself
            # is skipped on them.
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=metadata_dict["selected_token_indices"],
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, input_metadata,
                sampling_metadata, lora_requests, lora_mapping)
562

563
564
565
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass over the batch and sample the next tokens.

        Inputs come from ``prepare_input_tensors`` (built on the driver,
        received via broadcast elsewhere). When LoRA is enabled, the batch's
        adapters are activated first. The result of ``self.model.sample`` is
        returned as-is (presumably None on non-driver workers, whose
        SamplingMetadata has ``perform_sampling=False`` — confirm in Sampler).
        """
        prepared = self.prepare_input_tensors(seq_group_metadata_list)
        (input_tokens, input_positions, input_metadata, sampling_metadata,
         lora_requests, lora_mapping) = prepared

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Replay the CUDA graph captured for this (padded) batch size when
        # the inputs were prepared for it; otherwise run the model eagerly.
        if not input_metadata.use_cuda_graph:
            runner = self.model
        else:
            runner = self.graph_runners[input_tokens.shape[0]]

        hidden_states = runner(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )
        return self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one worst-case prefill batch so peak memory can be measured.

        Builds ``max_num_seqs`` dummy prompt groups whose lengths sum to
        ``max_num_batched_tokens``. It samples with top-k/top-p enabled so
        the sampler's memory is included. When LoRA is enabled, it spreads
        ``max_loras`` dummy adapters across the groups. The model runs with
        empty KV caches and no block tables, and the method then calls
        ``torch.cuda.synchronize()``.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        scheduler_config = self.scheduler_config
        max_num_batched_tokens = scheduler_config.max_num_batched_tokens
        max_num_seqs = scheduler_config.max_num_seqs

        # Register max_loras dummy adapters (built from the warmup path) and
        # assign them to sequences round-robin. This makes the profile cover
        # the largest number of distinct LoRAs a batch can hold, which is the
        # worst case for memory consumption.
        lora_request_per_seq: List[LoRARequest] = []
        if self.lora_config:
            warmup_requests: List[LoRARequest] = []
            for lora_id in range(1, self.lora_config.max_loras + 1):
                warmup_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(warmup_request,
                                                 rank=LORA_WARMUP_RANK)
                warmup_requests.append(warmup_request)
            num_warmup = len(warmup_requests)
            lora_request_per_seq = [
                warmup_requests[seq_idx % num_warmup]
                for seq_idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_seqs sequences whose total
        # length equals max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Split the token budget evenly; the first `extra` groups get one
            # additional token so the lengths add up exactly.
            base_len, extra = divmod(max_num_batched_tokens, max_num_seqs)
            seq_len = base_len + (group_id < extra)
            lora_request = (lora_request_per_seq[group_id]
                            if lora_request_per_seq else None)
            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: SequenceData([0] * seq_len)},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=lora_request,
                ))

        # Run the model with the dummy inputs and no KV caches.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        self.execute_model(seqs, [(None, None)] * num_layers)
        torch.cuda.synchronize()

651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
    def remove_all_loras(self) -> bool:
        """Delegate to the LoRA manager's ``remove_all_loras``.

        Returns:
            Whatever the manager's ``remove_all_loras`` returns.

        Raises:
            RuntimeError: if this runner has no (truthy) ``lora_manager``.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the active LoRA requests and mapping to the LoRA manager.

        Args:
            lora_requests: LoRA requests to make active.
            lora_mapping: Mapping handed unchanged to the manager.

        Raises:
            RuntimeError: if this runner has no (truthy) ``lora_manager``.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the LoRA manager.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: if this runner has no (truthy) ``lora_manager``.
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to remove the adapter with id ``lora_id``.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: if this runner has no (truthy) ``lora_manager``.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the LoRA manager's ``list_loras``.

        Raises:
            RuntimeError: if this runner has no (truthy) ``lora_manager``.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.list_loras()

677
678
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        """Capture decode-step CUDA graphs for every supported batch size.

        For each size in ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed the
        padded ``scheduler_config.max_num_seqs``, a ``CUDAGraphRunner`` is
        captured with dummy decode inputs (``is_prompt=False``) and stored in
        ``self.graph_runners[batch_size]``. Sizes are captured from largest to
        smallest, and each capture is handed the memory pool of the previous
        one through ``self.graph_memory_pool``.

        Args:
            kv_caches: Per-layer KV caches passed through to every capture.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # Allocating directly on the GPU avoids building each tensor on the
        # host first and then copying it over with `.cuda()`.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   1,
                                   dtype=torch.long,
                                   device="cuda")
        input_positions = torch.zeros(max_batch_size,
                                      1,
                                      dtype=torch.long,
                                      device="cuda")
        # Every dummy token is mapped to the padding slot id.
        slot_mapping = torch.full((max_batch_size, 1),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long,
                                  device="cuda")
        context_lens = torch.ones(max_batch_size,
                                  dtype=torch.int32,
                                  device="cuda")
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or CuPy NCCL if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy input_metadata.
                input_metadata = InputMetadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    max_seq_len=None,
                    start_loc=None,
                    max_context_len=self.max_context_len_to_capture,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    # All-zero mapping with no active LoRA requests.
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    input_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Hand this graph's pool to the next (smaller) capture.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

Woosuk Kwon's avatar
Woosuk Kwon committed
758
759
760
761
762
763
764
765
    def __del__(self) -> None:
        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        #
        # `getattr` guards against a partially constructed instance (for
        # example, if construction failed before `graph_runners` was
        # assigned). Without the guard, the resulting AttributeError would
        # surface as an "Exception ignored in __del__" warning.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.cupy_nccl_backend = None

766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786

class CUDAGraphRunner:
    """Record one model forward pass as a CUDA graph and replay it.

    ``capture`` first runs the model once without recording, then records a
    second run into a ``torch.cuda.CUDAGraph``. The tensors given to
    ``capture`` are kept as static input buffers. ``forward`` copies new
    inputs into those buffers, replays the graph, and returns the static
    output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Assigned by capture(); None until a graph has been recorded.
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        """Warm up the model, then record its forward pass into ``self.graph``.

        Args:
            input_ids: Token ids; kept as the static ``input_ids`` buffer.
            positions: Positions; kept as the static ``positions`` buffer.
            kv_caches: Per-layer KV caches used during capture.
            input_metadata: Attention metadata. Its ``slot_mapping``,
                ``context_lens`` and ``block_tables`` tensors are kept as
                static input buffers.
            memory_pool: Passed to ``torch.cuda.graph(..., pool=...)``.
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_cupy_nccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_cupy_nccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    input_metadata,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Copy inputs into the static buffers and replay the captured graph.

        The returned tensor is the same ``hidden_states`` object recorded in
        ``capture``. A later replay writes into it again, so callers that
        need the values to persist should copy them.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

851

852
853
854
855
856
857
858
859
860
@contextlib.contextmanager
def _maybe_cupy_nccl():
    """Enter ``with_cupy_nccl_for_all_reduce()`` only when it should be used.

    It is entered only when ``cupy_utils`` reports it is initialized and the
    custom all-reduce is not initialized. Otherwise the body runs with no
    extra context.
    """
    use_cupy_nccl = (cupy_utils.is_initialized()
                     and not custom_all_reduce.is_initialized())
    context = (with_cupy_nccl_for_all_reduce()
               if use_cupy_nccl else contextlib.nullcontext())
    with context:
        yield


861
862
863
864
865
866
867
868
869
870
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    """Return a new list: ``x`` right-padded with ``pad`` to length ``max_len``.

    ``x`` itself is not modified. An ``x`` longer than ``max_len`` fails the
    assertion.
    """
    missing = max_len - len(x)
    assert missing >= 0
    return x + [pad] * missing


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
    """Right-pad every row of ``x`` to ``max_len`` and build one tensor.

    Each row is padded with ``pad`` via ``_pad_to_max``, which asserts that
    no row is longer than ``max_len``. The result is created with ``dtype``
    on ``device``.

    NOTE(review): an empty ``x`` produces a tensor of shape ``(0,)``, not
    ``(0, max_len)``. Confirm that callers never rely on the 2-D shape in
    that case.
    """
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device=device)
875
876
877
878
879
880
881
882
883


def _get_graph_batch_size(batch_size: int) -> int:
    """Round ``batch_size`` up to the next captured CUDA-graph batch size.

    Sizes up to 2 are returned unchanged, 3 and 4 become 4, and anything
    larger is rounded up to a multiple of 8.
    """
    if batch_size > 4:
        # Integer ceiling division: ceil(batch_size / 8) * 8.
        return -(-batch_size // 8) * 8
    return batch_size if batch_size <= 2 else 4
884
885


886
887
888
889
890
891
892
893
def _async_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` and issue a copy to ``target_device``.

    The host tensor is allocated in pinned memory when ``pin_memory`` is true.
    The copy uses ``non_blocking=True``. Per the PyTorch docs, the copy is
    only truly asynchronous when the host tensor is pinned.
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
    return host_tensor.to(device=target_device, non_blocking=True)