"vscode:/vscode.git/clone" did not exist on "cf16342d435f5d6cbbeaf076ed4546bda0f89a20"
model_runner.py 39.1 KB
Newer Older
1
import contextlib
2
import time
3
from typing import Dict, List, Optional, Tuple, Set
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                         SchedulerConfig)
11
from vllm.logger import init_logger
12
13
from vllm.model_executor import InputMetadata, SamplingMetadata
from vllm.model_executor.model_loader import get_model
14
from vllm.model_executor.parallel_utils import cupy_utils
15
from vllm.model_executor.parallel_utils.communication_op import (
16
    broadcast_tensor_dict)
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)
19
from vllm.model_executor.parallel_utils import custom_all_reduce
20
21
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
22
23
24
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
25
26
27
from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
28
29
30

logger = init_logger(__name__)

# A pair of per-layer cache tensors; presumably (key_cache, value_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Placeholder slot id for tokens that have no real KV-cache slot: padding
# added for CUDA graphs, sliding-window-masked prompt positions, and dummy
# profiling inputs.
_PAD_SLOT_ID = -1
# LoRA rank of the dummy adapters registered by profile_run().
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Decode batch sizes for which CUDA graphs are captured: 1, 2, 4, then every
# multiple of _BATCH_SIZE_ALIGNMENT from 8 up to 256 (inclusive).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, 32 * _BATCH_SIZE_ALIGNMENT + 1,
          _BATCH_SIZE_ALIGNMENT))
40
41
42
43
44
45
46
47
48


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: Optional[DeviceConfig],
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        """Store configuration and set up lazily-initialized runner state.

        Args:
            model_config: Model configuration. May be None (some tests pass
                None), in which case no sliding window is used and the CUDA
                graph context-length limit is 0.
            parallel_config: Passed to get_model() and used to count layers.
            scheduler_config: Provides batch limits for LoRA and profiling.
            device_config: Device configuration; a default DeviceConfig() is
                substituted when None.
            lora_config: LoRA settings, or None to disable LoRA.
            kv_cache_dtype: KV-cache dtype string forwarded to InputMetadata.
            is_driver_worker: Whether this worker builds inputs and
                broadcasts them to the other workers.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None  # Created by load_model() if LoRA is on.

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype

88
    def load_model(self) -> None:
        """Load model weights and optionally wrap the model for LoRA.

        The memory consumed while loading is stored in
        ``self.model_memory_usage`` and logged after dividing by 2**30.
        When a LoRA config is present, the model must expose
        ``supported_lora_modules``, ``embedding_modules`` and
        ``embedding_padding_modules``. It is then wrapped by an
        LRUCacheWorkerLoRAManager.
        """
        with CudaMemoryProfiler() as profiler:
            self.model = get_model(self.model_config,
                                   self.device_config,
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config)
        self.model_memory_usage = profiler.consumed_memory
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if not self.lora_config:
            return

        model = self.model
        assert hasattr(model, "supported_lora_modules") and (
            model.supported_lora_modules), "Model does not support LoRA"
        assert hasattr(
            model, "embedding_modules"), "Model does not have embedding_modules"
        assert hasattr(model, "embedding_padding_modules"
                       ), "Model does not have embedding_padding_modules"
        self.lora_manager = LRUCacheWorkerLoRAManager(
            self.scheduler_config.max_num_seqs,
            self.scheduler_config.max_num_batched_tokens,
            self.vocab_size,
            self.lora_config,
            self.device,
            model.embedding_modules,
            model.embedding_padding_modules,
        )
        self.model = self.lora_manager.create_lora_manager(model)
115
116
117
118

    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the reusable int32
        block-table buffer used when replaying CUDA graphs."""
        self.block_size = block_size
        # One row per capturable batch size; one column per block needed to
        # cover max_context_len_to_capture tokens. block_size must be set
        # before get_max_block_per_batch() is called.
        num_rows = max(_BATCH_SIZES_TO_CAPTURE)
        num_cols = self.get_max_block_per_batch()
        self.graph_block_tables = np.zeros((num_rows, num_cols),
                                           dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_context_len_to_capture + block_size - 1) // block_size
126

127
128
129
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               List[int], List[int], Set[LoRARequest]]:
        """Build flattened model inputs for a batch of prompt (prefill) groups.

        Every group must be a prompt with exactly one sequence. Suppose
        ``computed_block_nums`` is non-empty and no sliding window is
        configured. Then the first ``len(computed_block_nums) * block_size``
        prompt tokens are treated as already cached. They are left out of
        the inputs and counted as context through the prefix block tables.

        Returns:
            (input_tokens, input_positions, input_metadata, prompt_lens,
            subquery_lens, lora_index_mapping, lora_prompt_mapping,
            lora_requests). ``subquery_lens[i]`` is the number of tokens of
            group ``i`` actually fed to the model.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            computed_len = 0
            # NOTE: This only works for oooooooxxx style attention.
            computed_block_nums = seq_group_metadata.computed_block_nums
            if (computed_block_nums is not None
                    and len(computed_block_nums) > 0
                    and self.sliding_window is None):
                # Prefix is not supported with sliding_window
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
            # The cached prefix (if any) is the context for the new tokens.
            context_lens.append(computed_len)
            # Number of tokens of this prompt that are actually fed in.
            subquery_len = prompt_len - computed_len
            subquery_lens.append(subquery_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(
                range(computed_len, computed_len + len(prompt_tokens)))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping += [lora_id] * subquery_len
            # Must agree with _prepare_sample, which selects every fed prompt
            # position whenever prompt_logprobs is not None (including 0).
            # Using truthiness here would under-count rows when
            # prompt_logprobs == 0.
            wants_prompt_logprobs = (
                seq_group_metadata.sampling_params.prompt_logprobs is not None)
            lora_prompt_mapping.extend(
                [lora_id] * (subquery_len if wants_prompt_logprobs else 1))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping with one
                # entry per token fed to the model.
                slot_mapping.extend([_PAD_SLOT_ID] * subquery_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(computed_len, prompt_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot_mapping.append(block_number * self.block_size +
                                    block_offset)

        max_subquery_len = max(subquery_lens)
        max_seq_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)
        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: element k is the start offset of group k.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])

        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_metadata = InputMetadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_seq_len=max_seq_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests)
293
294
295
296

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
               Set[LoRARequest]]:
        """Build model inputs for a batch of decode (generation) groups.

        Each sequence contributes exactly one token: its last token, placed at
        position ``len - 1``. CUDA graphs can be used when eager mode is off,
        the batch fits the largest captured size, and the longest context fits
        ``max_context_len_to_capture``. In that case:

        * the per-token lists are padded up to ``_get_graph_batch_size``;
        * block tables are written into the preallocated
          ``self.graph_block_tables`` buffer instead of being rebuilt.

        Returns:
            (input_tokens, input_positions, input_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)
        """
        assert len(seq_group_metadata_list) > 0
        block_size = self.block_size
        sliding_window = self.sliding_window

        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        for group in seq_group_metadata_list:
            assert not group.is_prompt

            lora_id = group.lora_int_id
            if lora_id > 0:
                lora_requests.add(group.lora_request)

            for seq_id, seq_data in group.seq_data.items():
                # Decoding feeds only the newest token, at the last position.
                input_tokens.append(seq_data.get_last_token_id())
                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # With a sliding window the context is capped at the window.
                if sliding_window is None:
                    context_lens.append(seq_len)
                else:
                    context_lens.append(min(seq_len, sliding_window))

                table = group.block_tables[seq_id]
                slot_mapping.append(table[position // block_size] *
                                    block_size + position % block_size)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if sliding_window is not None:
                    # Keep only the trailing blocks that cover the window.
                    table = table[-(sliding_window // block_size):]
                block_tables.append(table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # Every decode sequence contributes one token, so the batch size is
        # len(input_tokens).
        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            # Pad with dummy entries whose KV writes go to _PAD_SLOT_ID.
            # lora_prompt_mapping is deliberately not padded.
            num_pad = graph_batch_size - batch_size
            input_tokens.extend([0] * num_pad)
            input_positions.extend([0] * num_pad)
            slot_mapping.extend([_PAD_SLOT_ID] * num_pad)
            context_lens.extend([1] * num_pad)
            block_tables.extend([] for _ in range(num_pad))
            lora_index_mapping.extend([0] * num_pad)
            batch_size = graph_batch_size

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if not use_captured_graph:
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max(len(table) for table in block_tables),
                pad=0,
                dtype=torch.int,
                device=self.device,
            )
        else:
            # When using cuda-graph all these tensors should be padded.
            assert context_lens.shape[0] == input_tokens.shape[0]
            assert context_lens.shape[0] == input_positions.shape[0]
            assert context_lens.shape[0] == slot_mapping.shape[0]

            # graph_block_tables is [max batch size, max context len //
            # block size]. Slicing yields a numpy view, so these writes land
            # in the shared buffer.
            # NOTE(review): columns past len(table), and rows for padded
            # entries, keep whatever a previous call wrote; presumably
            # harmless because context_lens bounds what is read — confirm.
            input_block_tables = self.graph_block_tables[:batch_size]
            for row, table in enumerate(block_tables):
                if table:
                    input_block_tables[row, :len(table)] = table
            block_tables = torch.tensor(input_block_tables, device=self.device)

        input_metadata = InputMetadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            prompt_lens_tensor=None,
            num_prompt_tokens=0,
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_seq_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, input_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)
423
424
425
426
427

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the SamplingMetadata for one batch.

        Three cursors advance while the groups are walked in order:

        * ``token_cursor``: row in the flattened model input. It drives
          ``selected_token_indices``, the rows whose logits are kept.
        * ``sample_cursor``: row among the kept (selected) rows.
        * ``sampled_cursor``: index among the sequences that are sampled.

        ``categorized_sample_indices`` maps each SamplingType to
        ``(sample_cursor, sampled_cursor)`` pairs, moved to the device as int
        tensors.

        A prompt group with ``prompt_logprobs`` keeps every fed position but
        samples only from the last one. A prompt group with a seed gets a
        new seeded ``torch.Generator`` on ``state.generator``. Every seeded
        group contributes its generator to ``generators``.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        categorized_sample_indices = {t: [] for t in SamplingType}
        token_cursor = 0
        sample_cursor = 0
        sampled_cursor = 0

        for i, group in enumerate(seq_group_metadata_list):
            seq_ids = list(group.seq_data.keys())
            params = group.sampling_params
            seq_groups.append((seq_ids, params))

            if group.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                keep_prompt_rows = params.prompt_logprobs is not None
                last_row = token_cursor + subquery_len - 1

                if keep_prompt_rows:
                    # NOTE: prompt token positions do not need sample, skip
                    sample_cursor += subquery_len - 1
                categorized_sample_indices[params.sampling_type].append(
                    [sample_cursor, sampled_cursor])
                sample_cursor += 1
                sampled_cursor += 1

                if keep_prompt_rows:
                    selected_token_indices.extend(range(token_cursor, last_row))
                selected_token_indices.append(last_row)
                token_cursor += subquery_len

                if params.seed is not None:
                    group.state.generator = torch.Generator(
                        device=self.device).manual_seed(params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(token_cursor, token_cursor + num_seqs))
                token_cursor += num_seqs

                categorized_sample_indices[params.sampling_type].extend(
                    zip(range(sample_cursor, sample_cursor + num_seqs),
                        range(sampled_cursor, sampled_cursor + num_seqs)))
                sample_cursor += num_seqs
                sampled_cursor += num_seqs

            if params.seed is not None:
                generators.append(group.state.generator)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)
        categorized_sample_indices = {
            sampling_type: maybe_expand_dim(
                async_tensor_h2d(pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for sampling_type, pairs in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for group in seq_group_metadata_list:
            seq_data.update(group.seq_data)

        return SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )

522
523
524
    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping]]:
        """Prepare model inputs on the driver and share them with the workers.

        On the driver worker, inputs are built from ``seq_group_metadata_list``.
        Whether the batch is prefill or decode is decided by its first group.
        The result is broadcast with ``broadcast_tensor_dict(..., src=0)``.

        Non-driver workers ignore their argument and receive that broadcast.
        From it they rebuild InputMetadata and a SamplingMetadata with
        ``perform_sampling=False``.

        Returns:
            (input_tokens, input_positions, input_metadata, sampling_metadata,
            lora_requests, lora_mapping). ``lora_requests`` is a set of
            LoRARequest objects. ``lora_mapping`` is None when LoRA is
            disabled.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, input_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
            }
            metadata_dict.update(input_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            # Whatever keys remain came from InputMetadata.asdict_zerocopy().
            input_metadata = InputMetadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, input_metadata,
                sampling_metadata, lora_requests, lora_mapping)
586

587
588
589
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        """Run one forward step and sample the next tokens where allowed.

        Inputs come from prepare_input_tensors(). Logits are computed on
        every worker. Sampling happens only when
        ``sampling_metadata.perform_sampling`` is true; otherwise None is
        returned.
        """
        (input_tokens, input_positions, input_metadata, sampling_metadata,
         lora_requests, lora_mapping) = self.prepare_input_tensors(
             seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Graph-eligible decode batches were padded to a captured size, so
        # the padded token count selects the matching graph runner.
        forward = (self.graph_runners[input_tokens.shape[0]]
                   if input_metadata.use_cuda_graph else self.model)
        hidden_states = forward(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )

        logits = self.model.compute_logits(hidden_states, sampling_metadata)
        if sampling_metadata.perform_sampling:
            return self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        # Only the driver worker samples.
        return None

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one worst-case prefill batch of dummy data to profile memory.

        The batch holds ``max_num_seqs`` single-sequence prompt groups. Their
        lengths sum to ``max_num_batched_tokens``, and they have no block
        tables. When LoRA is enabled, ``max_loras`` dummy adapters are
        registered and assigned round-robin to the sequences. Every layer
        gets a ``(None, None)`` KV-cache placeholder.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # max_loras bounds how many distinct adapters can be active at once,
        # and therefore the maximum LoRA memory. Register one dummy adapter
        # per slot so profiling accounts for it.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            for lora_id in range(1, self.lora_config.max_loras + 1):
                warmup_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(warmup_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(warmup_request)
            num_dummy = len(dummy_lora_requests)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % num_dummy]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Split the token budget evenly; the first `remainder` groups
            # each take one extra token.
            base_len, remainder = divmod(max_num_batched_tokens, max_num_seqs)
            seq_len = base_len + (1 if group_id < remainder else 0)
            lora_request = (dummy_lora_requests_per_seq[group_id]
                            if dummy_lora_requests_per_seq else None)
            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: SequenceData([0] * seq_len)},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=lora_request,
                ))

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
    def remove_all_loras(self) -> bool:
        """Unload every LoRA adapter held by the LoRA manager.

        Returns:
            The value reported by the manager's ``remove_all_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the active LoRA requests and mapping to the LoRA manager.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the LoRA manager.

        Returns:
            The value reported by the manager's ``add_lora``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to drop the adapter with id ``lora_id``.

        Returns:
            The value reported by the manager's ``remove_lora``.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        if self.lora_manager:
            return self.lora_manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the set of LoRA ids reported by the LoRA manager.

        Raises:
            RuntimeError: If LoRA is not enabled.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since CUDA graphs require
        fixed-size tensors, supporting large/variable batch sizes requires
        high GPU memory overhead. Thus, vLLM only captures decoding requests.
        Mixed batches (chunked prefill + decoding) and prefill requests are
        not captured.

        Since it is used for decoding only, it assumes there is exactly one
        token per sequence in the batch.

        One graph is captured for every size in ``_BATCH_SIZES_TO_CAPTURE``
        that does not exceed the padded ``max_num_seqs``. All graphs share
        ``self.graph_memory_pool`` and are stored in ``self.graph_runners``
        keyed by batch size.

        Args:
            kv_caches: Per-layer KV caches the graphs are captured against;
                the captured graphs keep referring to these tensors.
        """
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # Allocate them directly on the GPU rather than building them on the
        # CPU and copying them over.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   dtype=torch.long,
                                   device="cuda")
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long,
                                      device="cuda")
        # Every slot of the dummy batch is the padding slot id.
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long,
                                  device="cuda")
        context_lens = torch.ones(max_batch_size,
                                  dtype=torch.int32,
                                  device="cuda")
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or CuPy NCCL if it is disabled or not supported.
        with custom_all_reduce.capture():
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy input_metadata.
                input_metadata = InputMetadata(
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
                    prompt_lens_tensor=None,
                    num_prompt_tokens=0,
                    num_generation_tokens=batch_size,
                    max_subquery_len=None,
                    max_context_len=self.max_context_len_to_capture,
                    max_seq_len=None,
                    subquery_start_loc=None,
                    seq_start_loc=None,
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                    kv_cache_dtype=self.kv_cache_dtype,
                )

                if self.lora_config:
                    # Capture with no active adapters and an all-zero mapping.
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    input_metadata,
                    memory_pool=self.graph_memory_pool,
                )
                # Reuse the first graph's pool for every subsequent capture.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    def __del__(self) -> None:
        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        # __del__ also runs on objects whose __init__ raised early, so the
        # attribute may be missing; a finalizer should not raise for that.
        graph_runners = getattr(self, "graph_runners", None)
        if graph_runners is not None:
            graph_runners.clear()
        self.cupy_nccl_backend = None

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        return self.model_config.get_vocab_size()


class CUDAGraphRunner:
    """Records one forward pass of ``model`` as a CUDA graph and replays it.

    The tensors passed to ``capture`` become the graph's static input
    buffers and its ``hidden_states`` result the static output buffer;
    ``forward`` copies fresh inputs into those buffers and replays the graph.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        """Warm the model up once, then record its forward pass as a graph.

        Args:
            input_ids, positions, kv_caches, input_metadata: Forward inputs;
                the tensors become the graph's static input buffers.
            memory_pool: Pool token forwarded to ``torch.cuda.graph`` so
                several graphs can share memory (``None`` for a fresh pool).
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_cupy_nccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_cupy_nccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    input_metadata,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs.

        Inputs are copied into the static buffers recorded by ``capture``,
        so they must be copy-compatible with them (presumably padded to the
        captured batch size by the caller - confirm).

        Returns:
            The static ``hidden_states`` buffer. The same tensor is returned
            on every call and rewritten by each replay.

        Raises:
            RuntimeError: If ``capture`` has not been called yet.
        """
        if self.graph is None:
            raise RuntimeError(
                "CUDAGraphRunner.forward() called before capture().")
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


@contextlib.contextmanager
def _maybe_cupy_nccl():
    """Route all-reduce through CuPy NCCL when that is the active backend.

    CuPy NCCL is selected only if CuPy has been initialized and the custom
    all-reduce kernel has not; otherwise this context manager is a no-op.
    """
    use_cupy_nccl = (cupy_utils.is_initialized()
                     and not custom_all_reduce.is_initialized())
    ctx = (with_cupy_nccl_for_all_reduce()
           if use_cupy_nccl else contextlib.nullcontext())
    with ctx:
        yield


def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
    For ``batch_size >= 1`` the result is the smallest entry of that
    sequence that is ``>= batch_size``. Must be kept in sync with
    ``_BATCH_SIZES_TO_CAPTURE``.
    """
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # Round up to the next multiple of _BATCH_SIZE_ALIGNMENT.
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)