"vllm/sequence.py" did not exist on "e86717833da1216222cf0d490c2e3ba198610b13"
model_runner.py 28.2 KB
Newer Older
1
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import in_wsl

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


class ModelRunner:
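    """Prepares batched inputs from scheduler-provided SequenceGroupMetadata,
    runs the model forward pass (eagerly or by replaying captured CUDA
    graphs), and samples the next tokens."""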

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        is_driver_worker: bool = False,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.model = None
        self.block_size = None  # Set after initial profiling.

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool = None  # Set during graph capture.

        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)
        # When using CUDA graph, the input block tables must be padded to
        # max_context_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = None  # Set after initial profiling.
        # cache in_wsl result
        self.in_wsl = in_wsl()

    def load_model(self) -> None:
        self.model = get_model(self.model_config)

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

        max_num_blocks = (self.max_context_len_to_capture + block_size -
                          1) // block_size
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
               List[int]]:
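        """Build padded input tensors for prompt (prefill) sequences.

        Returns the token, position, and slot-mapping tensors (each padded to
        the longest subquery in the batch) together with the per-sequence
        prompt lengths and subquery lengths (prompt length minus any cached
        prefix length).
        """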
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)
            prefix_len = 0
            prefix = seq_group_metadata.prefix
            if prefix is not None and prefix.computed:
                prefix_len = prefix.get_length()
                prompt_tokens = prompt_tokens[prefix_len:]
                prefix_block_tables.append(prefix.get_block_numbers())
            else:
                prefix_block_tables.append([])
            # actual prompt lens
            context_lens.append(prefix_len)
            subquery_lens.append(prompt_len - prefix_len)

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(
                list(range(prefix_len, prefix_len + len(prompt_tokens))))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert prefix_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prefix_len, prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(subquery_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
                                             dtype=torch.long)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_prompt_len,
                                                pad=0,
                                                dtype=torch.long)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long)
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device='cuda')
        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = _make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
        )
        start_loc_tensor = torch.arange(0,
                                        len(prompt_lens) * max_prompt_len,
                                        max_prompt_len,
                                        dtype=torch.long,
                                        device='cuda')
        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device='cuda')

        input_metadata = InputMetadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens_tensor,
            max_seq_len=max_prompt_len,
            start_loc=start_loc_tensor,
            max_context_len=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return (input_tokens, input_positions, input_metadata, prompt_lens,
                subquery_lens)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
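        """Build single-token input tensors for decode (generation) steps.

        When CUDA graphs can be used, the batch is padded up to the nearest
        captured batch size so that a pre-captured graph can be replayed.
        """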
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        batch_size = len(input_tokens)
        max_context_len = max(context_lens)
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_context_len <= self.max_context_len_to_capture)
        if use_captured_graph:
            # Pad the input tokens, positions, and slot mapping to match the
            # batch size of the captured graph.
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append([])
                input_positions.append([])
                slot_mapping.append([])
                context_lens.append(1)
                block_tables.append([])
            batch_size = graph_batch_size

        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long,
                                             device="cuda")
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_len=1,
                                                pad=0,
                                                dtype=torch.long,
                                                device="cuda")
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_len=1,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device="cuda")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device="cuda")

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device="cuda")
        else:
            max_block_table_len = (max_context_len + self.block_size -
                                   1) // self.block_size
            block_tables = _make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device="cuda",
            )

        input_metadata = InputMetadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            max_seq_len=None,
            start_loc=None,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
        return input_tokens, input_positions, input_metadata

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
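        """Select which output positions to sample from and group them by
        sampling type, accounting for prompt-logprob positions."""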
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_subquery_len = max(subquery_lens) if subquery_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += max_subquery_len
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(categorized_sample_indices_start_idx,
                              categorized_sample_indices_start_idx + num_seqs))
                categorized_sample_indices_start_idx += num_seqs

        selected_token_indices = _async_h2d(selected_token_indices,
                                            dtype=torch.long,
                                            pin_memory=not self.in_wsl)
        categorized_sample_indices = {
            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
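        """On the driver worker, build the input tensors and broadcast the
        metadata needed to reconstruct them; on other workers, rebuild the
        inputs from the broadcast metadata."""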
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, input_metadata
                 ) = self._prepare_decode(seq_group_metadata_list)
                subquery_lens = None
                prompt_lens = []
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "is_prompt": input_metadata.is_prompt,
                "slot_mapping": input_metadata.slot_mapping,
                "prompt_lens": input_metadata.prompt_lens,
                "max_seq_len": input_metadata.max_seq_len,
                "start_loc": input_metadata.start_loc,
                "max_context_len": input_metadata.max_context_len,
                "context_lens": input_metadata.context_lens,
                "block_tables": input_metadata.block_tables,
                "use_cuda_graph": input_metadata.use_cuda_graph,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict["input_tokens"]
            input_positions = metadata_dict["input_positions"]
            input_metadata = InputMetadata(
                is_prompt=metadata_dict["is_prompt"],
                slot_mapping=metadata_dict["slot_mapping"],
                prompt_lens=metadata_dict["prompt_lens"],
                max_seq_len=metadata_dict["max_seq_len"],
                start_loc=metadata_dict["start_loc"],
                max_context_len=metadata_dict["max_context_len"],
                context_lens=metadata_dict["context_lens"],
                block_tables=metadata_dict["block_tables"],
                use_cuda_graph=metadata_dict["use_cuda_graph"],
            )
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=metadata_dict["selected_token_indices"],
                categorized_sample_indices=None,
                perform_sampling=False,
            )

        return input_tokens, input_positions, input_metadata, sampling_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Optional[SamplerOutput]:
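        """Run one forward pass (eagerly or via a captured CUDA graph) and
        sample the next tokens for the given sequence groups."""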
        input_tokens, input_positions, input_metadata, sampling_metadata = (
            self.prepare_input_tensors(seq_group_metadata_list))
        # Execute the model.
        if input_metadata.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
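        """Run the model on dummy inputs that saturate max_num_batched_tokens
        and max_num_seqs so peak memory usage can be measured before the KV
        cache is allocated."""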
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()
        return

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
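        """Capture one CUDA graph per batch size in the capture list so that
        decode steps can later be replayed without kernel launch overhead."""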
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, 1,
                                      dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list):
            # Create dummy input_metadata.
            input_metadata = InputMetadata(
                is_prompt=False,
                slot_mapping=slot_mapping[:batch_size],
                prompt_lens=None,
                max_seq_len=None,
                start_loc=None,
                max_context_len=self.max_context_len_to_capture,
                context_lens=context_lens[:batch_size],
                block_tables=block_tables[:batch_size],
                use_cuda_graph=True,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                input_metadata,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")


class CUDAGraphRunner:
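    """Wraps the model so that a single decode step can be captured as a
    torch.cuda.CUDAGraph and replayed with fixed input/output buffers."""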

    def __init__(self, model: nn.Module):
        self.model = model
        self.graph = None
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        self.model(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
        )
        torch.cuda.synchronize()

        # Capture the graph.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
) -> torch.Tensor:
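    # Pad each row to max_len and stack the rows into a single tensor.
    # Pinned memory is only requested for CPU tensors.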
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x,
                        dtype=dtype,
                        device=device,
                        pin_memory=pin_memory and str(device) == "cpu")


def _get_graph_batch_size(batch_size: int) -> int:
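    # Round the batch size up to the nearest captured size: 1 and 2 map to
    # themselves, 3-4 map to 4, and larger sizes round up to a multiple of 8.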
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        return (batch_size + 7) // 8 * 8


def _async_h2d(data: list, dtype, pin_memory):
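    # Create the tensor on the host (optionally in pinned memory) and copy it
    # to the GPU asynchronously.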
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
    return t.to(device="cuda", non_blocking=True)