import contextlib
import time
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils, custom_all_reduce
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Slot id written for tokens that get no KV-cache slot: dummy prompts during
# memory profiling, tokens masked out by the sliding window, and the padding
# rows added to fill a captured CUDA-graph batch.
_PAD_SLOT_ID = -1
# Rank of the dummy LoRA adapters registered by ModelRunner.profile_run.
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for batch sizes 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        """Store the configs and initialise state filled in by later calls.

        Args:
            model_config: Model configuration. May be None (see the FIXME
                below); then the sliding window is None, the CUDA-graph
                context length is 0 and the attention backend is chosen
                with a None dtype.
            parallel_config: Parallelism configuration.
            scheduler_config: Scheduler configuration.
            device_config: Device configuration; ``DeviceConfig()`` is used
                when None is passed.
            lora_config: LoRA configuration, or None when LoRA is disabled.
            kv_cache_dtype: KV-cache dtype string, forwarded into the
                attention metadata.
            is_driver_worker: Whether this runner prepares and broadcasts the
                model inputs (see ``prepare_input_tensors``).
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.model = None
        self.block_size = None  # Set after initial profiling.
        self.lora_manager = None  # Set in load_model when LoRA is enabled.

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

    def load_model(self) -> None:
        """Load the model weights and, if LoRA is enabled, wrap the model.

        The memory consumed while loading is stored in
        ``self.model_memory_usage`` and logged in GB. When
        ``self.lora_config`` is set, the model must expose a non-empty
        ``supported_lora_modules`` plus ``embedding_modules`` and
        ``embedding_padding_modules``. An ``LRUCacheWorkerLoRAManager`` is
        then created and ``self.model`` is replaced by the manager-wrapped
        model.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(self.model_config,
                                   self.device_config,
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config)

        self.model_memory_usage = m.consumed_memory
        # %-style arguments: the message is only formatted if INFO is enabled.
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert getattr(self.model, "supported_lora_modules", None), (
                "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

    def set_block_size(self, block_size: int) -> None:
        """Record the KV-cache block size and allocate the graph block table.

        The cached table is a zero-filled int32 numpy array with one row per
        largest captured batch slot and enough columns to cover
        ``max_context_len_to_capture``. ``_prepare_decode`` copies real block
        tables into it whenever a captured CUDA graph is used.
        """
        # Must be assigned first: get_max_block_per_batch reads it.
        self.block_size = block_size
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32,
        )

    def get_max_block_per_batch(self) -> int:
        """Return the number of blocks covering max_context_len_to_capture.

        This is a ceiling division by ``self.block_size``. ``block_size`` is
        None until ``set_block_size`` has been called, so calling this
        earlier fails.
        """
        block_size = self.block_size
        return (self.max_context_len_to_capture + block_size - 1) // block_size

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest]]:
        """Build flattened model inputs for a batch of prefill requests.

        Every group must be a prompt holding exactly one sequence. When
        ``computed_block_nums`` is non-empty and no sliding window is used,
        those leading blocks are treated as already computed (prefix
        caching). Only the remaining tokens are fed to the model.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, prompt_lens,
            subquery_lens, lora_index_mapping, lora_prompt_mapping,
            lora_requests)``. ``input_tokens`` and ``input_positions`` are
            1-D long tensors on ``self.device``. ``prompt_lens`` holds the
            full prompt lengths and ``subquery_lens`` the number of tokens
            actually fed per prompt.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            # NOTE: This only works for oooooooxxx style attention.
            # Prefix caching is not supported with a sliding window, so the
            # computed prefix is only honoured when sliding_window is None.
            computed_len = 0
            computed_block_nums = seq_group_metadata.computed_block_nums
            if (computed_block_nums is not None
                    and len(computed_block_nums) > 0
                    and self.sliding_window is None):
                computed_len = len(computed_block_nums) * self.block_size
                prompt_tokens = prompt_tokens[computed_len:]
                prefix_block_tables.append(computed_block_nums)
            else:
                prefix_block_tables.append([])
            # The computed prefix is the attention context; the remaining
            # tokens form the query ("subquery") of this prefill.
            context_lens.append(computed_len)
            subquery_len = prompt_len - computed_len
            subquery_lens.append(subquery_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(
                range(computed_len, computed_len + len(prompt_tokens)))

            lora_id = seq_group_metadata.lora_int_id
            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            lora_index_mapping.extend([lora_id] * subquery_len)
            # Must mirror _prepare_sample: every fed token is a selected row
            # when prompt_logprobs is set (including 0), otherwise only the
            # last token is.
            prompt_logprobs = seq_group_metadata.sampling_params.prompt_logprobs
            lora_prompt_mapping.extend(
                [lora_id] *
                (subquery_len if prompt_logprobs is not None else 1))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping with one
                # entry per token fed for this prompt.
                slot_mapping.extend([_PAD_SLOT_ID] * subquery_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert computed_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(computed_len, prompt_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        max_subquery_len = max(subquery_lens)
        max_prompt_len = max(prompt_lens)
        num_prompt_tokens = len(input_tokens)
        assert max_subquery_len > 0

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)
        # Prepare prefix block tables, padded with 0 to the longest one.
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        # Query length can be shorter than key (i.e., prompt) when prefill
        # is chunked or prefix cached.
        subquery_lens_tensor = torch.tensor(subquery_lens,
                                            dtype=torch.long,
                                            device=self.device)
        subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
                                         dtype=torch.int32,
                                         device=self.device)

        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device=self.device)
        seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry i is the start offset of prompt i when
        # the per-prompt lengths are laid out back to back; entry 0 stays 0.
        torch.cumsum(subquery_lens_tensor,
                     dim=0,
                     dtype=subquery_start_loc.dtype,
                     out=subquery_start_loc[1:])
        torch.cumsum(prompt_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens,
            prompt_lens_tensor=prompt_lens_tensor,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=0,
            max_subquery_len=max_subquery_len,
            max_context_len=None,
            max_prompt_len=max_prompt_len,
            subquery_start_loc=subquery_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], Set[LoRARequest]]:
        """Build model inputs for a batch of decode (generation) requests.

        Each sequence contributes its last token at position ``len - 1``.
        When a captured CUDA graph can be used, the batch is padded up to a
        captured batch size and the block tables are written into the cached
        ``self.graph_block_tables`` buffer.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            lora_index_mapping, lora_prompt_mapping, lora_requests)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            lora_id = seq_group_metadata.lora_int_id

            if lora_id > 0:
                lora_requests.add(seq_group_metadata.lora_request)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
                lora_index_mapping.append(lora_id)
                lora_prompt_mapping.append(lora_id)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks inside the window.
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        # vLLM uses cuda graph only for decoding requests.
        # See `capture_model` API for more details.
        # For decoding requests, batch_size == len(input_tokens).
        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            # Pad with dummy rows so the batch matches a captured graph.
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                context_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        if use_captured_graph:
            # When using cuda-graph all these tensors should be
            # padded.
            assert context_lens.shape[0] == input_tokens.shape[0]
            assert context_lens.shape[0] == input_positions.shape[0]
            assert context_lens.shape[0] == slot_mapping.shape[0]

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            # This is a view into the cached buffer: only the first
            # len(block_table) entries of each non-empty row are overwritten,
            # other entries keep whatever the buffer already held.
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            prompt_lens_tensor=None,
            num_prompt_tokens=0,
            num_generation_tokens=len(input_tokens),
            max_subquery_len=None,
            max_context_len=max_context_len,
            max_prompt_len=None,
            subquery_start_loc=None,
            seq_start_loc=None,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata,
                lora_index_mapping, lora_prompt_mapping, lora_requests)

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        """Build the SamplingMetadata describing which rows get sampled.

        ``selected_token_indices`` indexes the flattened model output:
        - For a prompt group it selects the last fed token, plus every
          earlier fed token when ``prompt_logprobs`` is not None.
        - For a decode group it selects one row per sequence.

        ``categorized_sample_indices`` maps each SamplingType to pairs of
        (position among the selected rows, running index of sampled tokens).
        Each value is passed through ``maybe_expand_dim(..., 2, 2)``.
        Seeded prompt groups get a fresh ``torch.Generator`` stored on
        ``seq_group_metadata.state``.

        ``subquery_lens`` must be provided whenever the list contains prompt
        groups.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # The earlier prompt rows are selected only for logprobs
                    # and are not sampled; skip past them.
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append([
                        categorized_sample_indices_start_idx,
                        categorized_sampled_token_indices_start_idx
                    ])
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += subquery_len

                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)

        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(index_pairs,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for t, index_pairs in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[LoRARequest], Optional[LoRAMapping]]:
        """Prepare model inputs on the driver and share them with other workers.

        The driver builds the inputs from ``seq_group_metadata_list``. All
        groups must be prompts or all decodes; the first group decides which.
        The driver then broadcasts tokens, positions, selected token indices,
        LoRA info and the attention metadata fields from rank 0. Non-driver
        workers ignore the argument, rebuild the attention metadata from the
        broadcast, and get a SamplingMetadata with ``perform_sampling=False``.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping)``.
            ``lora_mapping`` is None when LoRA is disabled.
        """
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            # Whatever is left in the dict are the attention metadata fields.
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample the next tokens on the driver.

        The inputs come from ``prepare_input_tensors``. When the attention
        metadata says a CUDA graph is used, the captured runner for the
        padded batch size is called instead of the eager model. Logits are
        computed on every worker. Workers whose SamplingMetadata has
        ``perform_sampling`` False (the non-driver path) return None.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests,
         lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Execute the model.
        if attn_metadata.use_cuda_graph:
            # Inputs were padded to a captured size in _prepare_decode.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass over dummy prompts to measure peak memory.

        ``max_num_seqs`` single-sequence prompts of token id 0 are built so
        that their lengths sum to ``max_num_batched_tokens``. Block tables
        are None and the KV caches are a list of None per layer. With LoRA
        enabled, ``max_loras`` dummy adapters of rank ``LORA_WARMUP_RANK``
        are registered and assigned round-robin to the prompts.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # max_loras is the maximum number of distinct LoRAs that can be
        # active at once, and therefore bounds LoRA memory consumption.
        # Register that many dummy LoRA adapters (with fake paths) and
        # spread them over the dummy sequences.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            for idx in range(self.lora_config.max_loras):
                lora_id = idx + 1
                dummy_lora_request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_local_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                 rank=LORA_WARMUP_RANK)
                dummy_lora_requests.append(dummy_lora_request)
            dummy_lora_requests_per_seq = [
                dummy_lora_requests[idx % len(dummy_lora_requests)]
                for idx in range(max_num_seqs)
            ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            # Spread the remainder tokens over the first groups.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

    def remove_all_loras(self) -> bool:
        """Delegate to the LoRA manager's ``remove_all_loras``.

        Returns:
            Whatever the LoRA manager's ``remove_all_loras`` returns.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        lora_manager = self.lora_manager
        if lora_manager:
            return lora_manager.remove_all_loras()
        raise RuntimeError("LoRA is not enabled.")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` on the LoRA manager using ``lora_mapping``.

        ``lora_requests`` is annotated as a ``Set`` because callers pass a set
        (e.g. ``capture_model`` passes ``set()``). The value is forwarded
        unchanged to the LoRA manager.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate ``lora_request`` to the LoRA manager's ``add_lora``.

        Returns:
            Whatever the LoRA manager's ``add_lora`` returns.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate removal of adapter ``lora_id`` to the LoRA manager.

        Returns:
            Whatever the LoRA manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the LoRA manager's ``list_loras``.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

709
    @torch.inference_mode()
710
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
711
712
713
714
715
716
717
718
719
720
721
722
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.
        """
Woosuk Kwon's avatar
Woosuk Kwon committed
723
724
        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
        # deleted before the CUDA graphs.
725
        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
Woosuk Kwon's avatar
Woosuk Kwon committed
726

727
728
729
730
731
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
732
733
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
734
735
736
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
737
738
739
740
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
741
742
743
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
744
745
746
747
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

748
749
750
751
752
753
        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

Woosuk Kwon's avatar
Woosuk Kwon committed
754
755
756
757
758
759
        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
        # We always prioritize using custom all-reduce kernel but fall back
        # to PyTorch or CuPy NCCL if it is disabled or not supported.
760
        with custom_all_reduce.capture():
761
762
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
763
            for batch_size in reversed(batch_size_capture_list):
764
765
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
766
767
768
                    is_prompt=False,
                    slot_mapping=slot_mapping[:batch_size],
                    prompt_lens=None,
769
770
771
772
                    prompt_lens_tensor=None,
                    num_prompt_tokens=0,
                    num_generation_tokens=batch_size,
                    max_subquery_len=None,
773
                    max_context_len=self.max_context_len_to_capture,
774
                    max_prompt_len=None,
775
776
                    subquery_start_loc=None,
                    seq_start_loc=None,
777
778
779
                    context_lens=context_lens[:batch_size],
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
780
                    kv_cache_dtype=self.kv_cache_dtype,
781
                )
782

783
784
785
786
787
788
789
790
791
792
793
794
                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
795
                    attn_metadata,
796
                    memory_pool=self.graph_memory_pool,
797
                )
798
799
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner
800
801
802
803
804
805

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

Woosuk Kwon's avatar
Woosuk Kwon committed
806
807
808
809
810
811
812
813
    def __del__(self) -> None:
        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
        # NOTE(woosuk): This is necessary because otherwise deadlocks can
        # happen.
        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
        self.graph_runners.clear()
        self.cupy_nccl_backend = None

814
815
816
817
    @property
    def vocab_size(self) -> int:
        """Vocabulary size, as reported by ``model_config.get_vocab_size()``."""
        return self.model_config.get_vocab_size()


class CUDAGraphRunner:
    """Capture one forward pass of a model as a CUDA graph and replay it.

    ``capture`` records ``self.model`` on fixed-size input tensors and keeps
    those tensors as the graph's static input buffers. ``forward`` copies new
    inputs into the buffers, replays the graph, and returns the static output
    tensor. The same tensor object is returned on every call, so its contents
    are overwritten by the next replay.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Set by ``capture``; ``None`` until a graph has been recorded.
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
    ) -> None:
        """Record one forward pass of ``self.model`` into ``self.graph``.

        May be called only once per runner (asserted).

        Args:
            input_ids: Token-id tensor kept as a static input buffer.
            positions: Position tensor kept as a static input buffer.
            kv_caches: KV cache tensors; captured by reference, not copied.
            attn_metadata: Metadata whose ``slot_mapping``, ``context_lens``
                and ``block_tables`` are kept as static input buffers.
            memory_pool: Memory pool handle passed to ``torch.cuda.graph`` so
                this graph can share memory with earlier captures.
                NOTE(review): presumably ``None`` for the first capture —
                confirm against the caller.
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_cupy_nccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with _maybe_cupy_nccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    attn_metadata,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers. ``forward`` writes new inputs
        # into these exact tensors before each replay.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "context_lens": attn_metadata.context_lens,
            "block_tables": attn_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Copy the inputs into the static buffers and replay the graph.

        The inputs must be broadcast-compatible with the buffers recorded by
        ``capture`` (``Tensor.copy_`` semantics). Returns the static
        ``hidden_states`` output tensor, which is overwritten on every replay.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
                                                 non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


@contextlib.contextmanager
def _maybe_cupy_nccl():
    """Route all-reduce through CuPy NCCL when it is the backend to use.

    Enters ``with_cupy_nccl_for_all_reduce()`` only when CuPy NCCL has been
    initialized and the custom all-reduce kernel has not. Otherwise the
    context does nothing.
    """
    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
        with with_cupy_nccl_for_all_reduce():
            yield
    else:
        yield


def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    Sizes 1 and 2 are returned unchanged, 3 and 4 map to 4, and anything
    larger is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    This must stay in sync with ``_BATCH_SIZES_TO_CAPTURE``.
    """
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        # Ceiling division, then scale back up to the alignment.
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)