model_runner.py 71.9 KB
Newer Older
1
import dataclasses
2
import gc
3
import inspect
4
import itertools
5
import time
6
import warnings
7
import weakref
8
from dataclasses import dataclass
9
10
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                    Tuple, Type, TypeVar, Union)
11

12
import numpy as np
13
import torch
14
import torch.distributed
15
import torch.nn as nn
16

17
import vllm.envs as envs
18
from vllm.attention import AttentionMetadata, get_attn_backend
19
20
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
21
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
22
23
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
24
from vllm.distributed import get_pp_group
25
from vllm.distributed.parallel_state import graph_capture
26
from vllm.inputs import INPUT_REGISTRY, InputRegistry
27
from vllm.logger import init_logger
28
29
30
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
31
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
32
from vllm.model_executor.model_loader import get_model
33
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
34
from vllm.model_executor.models.interfaces import (supports_lora,
35
                                                   supports_multimodal)
36
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
37
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
38
                             MultiModalInputs, MultiModalRegistry)
39
40
41
42
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
43
from vllm.sampling_params import SamplingParams
44
45
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
46
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
47
48
                        flatten_2d_lists, is_hip, is_pin_memory_available,
                        supports_dynamo)
49
from vllm.worker.model_runner_base import (
50
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
51
52
53
54
55
56
57
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
58
59
60

logger = init_logger(__name__)

LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# CUDA graphs are captured for batch sizes 1, 2, 4 and then every multiple of
# _BATCH_SIZE_ALIGNMENT up to 256 (8, 16, 24, ..., 256). Graphs are only used
# for decoding, so the batch size equals the number of input tokens.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))
_NUM_WARMUP_ITERS = 2

TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")


@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    Metadata needed for the base model forward pass, excluding anything that
    is only used by later steps such as sampling. Model runners that run
    additional steps should subclass this class to add their own fields.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
    prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None
    use_async_and_multi_step: bool = False

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect the fields that are broadcast to other workers.

        ``seq_lens``, ``query_lens``, ``async_callback`` and
        ``use_async_and_multi_step`` are not included. Attention metadata is
        added to the dict by ``_add_attn_metadata_broadcastable_dict``.
        """
        broadcast_fields = (
            "input_tokens",
            "input_positions",
            "lora_requests",
            "lora_mapping",
            "multi_modal_kwargs",
            "prompt_adapter_mapping",
            "prompt_adapter_requests",
            "virtual_engine",
            "request_ids_to_seq_ids",
            "finished_requests_ids",
        )
        tensor_dict = {name: getattr(self, name) for name in broadcast_fields}
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Rebuild an instance from a dict made by
        ``as_broadcastable_tensor_dict``.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict`` to restore attention
        metadata; otherwise it is used as-is.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        return cls(**_init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict))


@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner: the base model inputs plus sampling metadata.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the parent's broadcastable fields plus sampling metadata.

        Delegating to the parent keeps the list of broadcast keys in a single
        place instead of duplicating (and risking drift from) it here.
        """
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Rebuild an instance from a dict made by
        ``as_broadcastable_tensor_dict``.

        Sampling metadata is restored first; attention-metadata restoration
        and construction are then delegated to the parent (still building
        ``cls``).
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        return super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)


class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
167
168
    """Build ModelInputForGPU from SequenceGroupMetadata."""

169
170
171
    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
172
173
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group.

        Instances are pooled by the builder (see ``init_cached_inter_data``)
        and refilled by calling ``__init__`` again with ``reinit=True``. In
        that mode, containers the instance already owns are reset in place
        whenever no replacement value is supplied.
        """

        @staticmethod
        def _clear_rows(rows, count):
            # Empty the first ``count`` inner lists without reallocating.
            for idx in range(count):
                rows[idx].clear()

        @staticmethod
        def _zero_slots(slots, count):
            # Zero the first ``count`` entries of a per-sequence int list.
            for idx in range(count):
                slots[idx] = 0

        def simple_reinit(self):
            """Reset a recycled single-sequence group (slot 0 only)."""
            for rows in (self.input_tokens,  # type: ignore
                         self.input_positions):  # type: ignore
                rows[0].clear()
            for slots in (self.seq_lens,  # type: ignore
                          self.orig_seq_lens,  # type: ignore
                          self.query_lens,  # type: ignore
                          self.context_lens,  # type: ignore
                          self.curr_sliding_window_blocks):  # type: ignore
                slots[0] = 0
            for container in (self.lora_index_mapping,  # type: ignore
                              self.lora_prompt_mapping,  # type: ignore
                              self.lora_requests,  # type: ignore
                              self.prompt_adapter_index_mapping,  # type: ignore
                              self.prompt_adapter_prompt_mapping):  # type: ignore
                container.clear()

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,

            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_inputs: Optional[MultiModalInputs] = None,

            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,

            reinit: bool = False,
            reinit_use_defaults: bool = False,
        ):
            if reinit:
                # Copy the new ids into the list this instance already owns
                # (instead of rebinding it) so the pooled object is reused.
                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                for pos, new_id in enumerate(seq_ids):
                    self.seq_ids[pos] = new_id  # type: ignore
            else:
                self.seq_ids = seq_ids

            self.request_id = request_id
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs

            if not reinit:
                # Fresh object: adopt the given containers or start empty.
                # NOTE(review): __post_init__ (called below) overwrites the
                # token/position/length lists and the LoRA mappings, so
                # values passed for those fields are discarded here.
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.query_lens = query_lens or []
                self.context_lens = context_lens or []
                self.curr_sliding_window_blocks = (
                    curr_sliding_window_blocks or [])
                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()
                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])
            elif reinit_use_defaults and len(self.seq_ids) == 1:
                self.simple_reinit()
            else:
                count = len(self.seq_ids)
                # For each field, adopt a non-empty supplied value; otherwise
                # reset the container this instance already holds.
                for attr, supplied in (
                    ("input_tokens", input_tokens),
                    ("input_positions", input_positions),
                ):
                    if supplied:
                        setattr(self, attr, supplied)
                    else:
                        self._clear_rows(getattr(self, attr), count)

                for attr, supplied in (
                    ("seq_lens", seq_lens),
                    ("orig_seq_lens", orig_seq_lens),
                    ("query_lens", query_lens),
                    ("context_lens", context_lens),
                    ("curr_sliding_window_blocks",
                     curr_sliding_window_blocks),
                ):
                    if supplied:
                        setattr(self, attr, supplied)
                    else:
                        self._zero_slots(getattr(self, attr), count)

                for attr, supplied in (
                    ("lora_index_mapping", lora_index_mapping),
                    ("lora_prompt_mapping", lora_prompt_mapping),
                    ("lora_requests", lora_requests),
                    ("prompt_adapter_index_mapping",
                     prompt_adapter_index_mapping),
                    ("prompt_adapter_prompt_mapping",
                     prompt_adapter_prompt_mapping),
                ):
                    if supplied:
                        setattr(self, attr, supplied)
                    else:
                        getattr(self, attr).clear()

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_inputs = multi_modal_inputs
            self.prefix_cache_hit = prefix_cache_hit

            self.n_seqs = len(self.seq_ids)

            if not reinit:
                self.__post_init__()

        def __post_init__(self):
            """Allocate fresh per-sequence containers sized to ``seq_ids``."""
            n = self.n_seqs = len(self.seq_ids)

            self.input_tokens, self.input_positions = (
                [[] for _ in range(n)] for _copy in range(2))
            (self.seq_lens, self.orig_seq_lens, self.query_lens,
             self.context_lens, self.curr_sliding_window_blocks) = (
                 [0] * n for _copy in range(5))

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []

    def gen_inter_data_builder(self, num_seqs: int):
        """Return a zero-argument factory for placeholder group objects.

        Each call of the factory creates an ``InterDataForSeqGroup`` with
        ``num_seqs`` dummy sequence ids, to be refilled later in place
        (see ``init_cached_inter_data``).
        """

        def _new_placeholder():
            # A fresh seq_ids list per object, so pooled objects never share
            # one.
            return ModelInputForGPUBuilder.InterDataForSeqGroup(
                request_id="",
                seq_ids=[0] * num_seqs,
                is_prompt=True,
                block_tables=None,
                computed_block_nums=[],
            )

        return _new_placeholder

    def init_cached_inter_data(self, *args, **kwargs):
        """Fetch a pooled ``InterDataForSeqGroup`` and re-initialize it.

        Pools live on the runner (``runner.inter_data_cache``) and are keyed
        by the group's number of sequences. A pool is created lazily the
        first time a size is seen. Only keyword arguments are accepted, and
        ``seq_ids`` is required.
        """
        assert not args
        assert "seq_ids" in kwargs
        group_size = len(kwargs["seq_ids"])

        # The inter-data cache is per model_runner.
        pools = self.runner.inter_data_cache
        if group_size not in pools:
            pools[group_size] = PyObjectCache(
                self.gen_inter_data_builder(group_size))

        inter_data = pools[group_size].get_object()
        # Re-run the constructor on the recycled instance.
        inter_data.__init__(*args, **kwargs)
        return inter_data

    def reset_cached_inter_data(self):
        """Call ``reset()`` on every pool in ``runner.inter_data_cache``."""
        for pool in self.runner.inter_data_cache.values():
            pool.reset()

    def __init__(self,
                 runner: "GPUModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None):
        """Set up a builder that turns sequence groups from ``runner``'s
        batch into a model input.

        Args:
            runner: The owning model runner; configuration is read from it.
            finished_requests_ids: Ids of requests that have finished, kept
                as-is on the builder.
        """
        super().__init__()

        # Steps run for every sequence of a group, in list order. The order
        # is load-bearing: _compute_lens fills the lengths that the
        # prefix-cache, sliding-window and LoRA steps then read or adjust.
        self.per_seq_compute_fns = [
            self._compute_lens,
            self._compute_for_prefix_cache_hit,
            self._compute_for_sliding_window,
            self._compute_lora_input,
        ]
        # Steps run once per sequence group; also order-sensitive.
        self.per_seq_group_compute_fns = [
            self._compute_prompt_adapter_input,
            self._compute_multi_modal_input,
        ]

        # Settings mirrored from the owning runner.
        self.runner = runner
        self.model_input_cls = runner._model_input_cls
        self.attn_backend = runner.attn_backend
        self.scheduler_config = runner.scheduler_config
        self.sliding_window = runner.sliding_window
        self.block_size = runner.block_size
        self.enable_lora = runner.lora_config is not None
        self.enable_prompt_adapter = runner.prompt_adapter_config is not None
        self.multi_modal_input_mapper = runner.multi_modal_input_mapper
        self.finished_requests_ids = finished_requests_ids
        # Cleared by add_seq_group as soon as a prompt group is added.
        self.decode_only = True

        # Intermediate data (CPU-side, before going to GPU), one entry per
        # added sequence group.
        self.inter_data_list: List[
            ModelInputForGPUBuilder.InterDataForSeqGroup] = []

        # Attention metadata inputs. Created after the attributes above, and
        # handed a weak proxy of this builder. NOTE(review): the backend's
        # builder may read those attributes through the proxy; keep this
        # ordering unless confirmed otherwise.
        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
            weakref.proxy(self))

        # Engine/model configuration.
        sched = self.scheduler_config
        self.chunked_prefill_enabled = (sched is not None
                                        and sched.chunked_prefill_enabled)
        if self.sliding_window is not None:
            # Sliding window rounded up to a whole number of blocks.
            self.sliding_window_blocks = (
                self.sliding_window + self.block_size - 1) // self.block_size
            self.block_aligned_sliding_window = (
                self.sliding_window_blocks * self.block_size)

    def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                      seq_group_metadata: SequenceGroupMetadata):
        """Fill lengths, tokens and positions for one sequence.

        Sets ``seq_lens``/``orig_seq_lens`` (capped to the current chunk),
        ``context_lens`` (tokens already computed) and ``query_lens``, and
        appends this step's tokens and positions to the sequence's rows.
        """
        seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
        chunk_size = seq_group_metadata.token_chunk_size
        prompt = inter_data.is_prompt

        total_len = seq_data.get_len()
        # get_num_computed_tokens is incorrect for spec decoding, so decode
        # steps assume everything but the last token is computed.
        # TODO(sang): Fix it.
        context_len = (seq_data.get_num_computed_tokens()
                       if prompt else total_len - 1)
        seq_len = min(total_len, context_len + chunk_size)

        if prompt:
            tokens = seq_data.get_token_ids()
            if context_len != 0 or seq_len < len(tokens):
                tokens = tokens[context_len:seq_len]
        else:
            # Optimization: avoid get_token_ids, which copies every token.
            tokens = seq_data.get_last_token_id()

        inter_data.seq_lens[seq_idx] = seq_len
        inter_data.orig_seq_lens[seq_idx] = seq_len
        inter_data.context_lens[seq_idx] = context_len

        token_row = inter_data.input_tokens[seq_idx]
        if isinstance(tokens, list):
            token_row.extend(tokens)
        else:
            token_row.append(tokens)

        position_row = inter_data.input_positions[seq_idx]
        if seq_len - context_len == 1:
            position_row.append(seq_len - 1)
        else:
            position_row.extend(range(context_len, seq_len))

        inter_data.query_lens[seq_idx] = (seq_len -
                                          context_len) if prompt else 1

    def _compute_for_prefix_cache_hit(
            self, inter_data: InterDataForSeqGroup, seq_idx: int,
            seq_group_metadata: SequenceGroupMetadata):
        """Trim a prompt to the part not covered by cached blocks.

        A prefix-cache hit requires a non-empty ``computed_block_nums``, no
        sliding window and a prompt group; the outcome is stored in
        ``inter_data.prefix_cache_hit``. On a hit, the input tokens and
        positions, context length and query length are adjusted so only the
        uncached remainder is computed.
        """
        block_nums = inter_data.computed_block_nums

        # Note that prefix caching does not support sliding window.
        hit = (bool(block_nums) and self.sliding_window is None
               and inter_data.is_prompt)
        inter_data.prefix_cache_hit = hit
        if not hit:
            return

        assert block_nums is not None
        # Cache-hit prompt tokens; may exceed the sequence length when
        # chunked prefill is enabled.
        cached_len = len(block_nums) * self.block_size
        # Prompt tokens computed so far.
        context_len = inter_data.context_lens[seq_idx]
        # Computed chunks + current chunk (with chunked prefill), i.e. the
        # total prompt tokens considered for this sequence.
        seq_len = inter_data.seq_lens[seq_idx]

        if cached_len <= context_len:
            # Already past the cache-hit region: compute normally.
            return

        if cached_len < seq_len:
            # Partial hit: skip the cached tokens at the front of this chunk.
            skip = cached_len - context_len
            inter_data.input_tokens[seq_idx] = (
                inter_data.input_tokens[seq_idx][skip:])
            inter_data.input_positions[seq_idx] = (
                inter_data.input_positions[seq_idx][skip:])
            inter_data.context_lens[seq_idx] = cached_len
            inter_data.query_lens[seq_idx] = seq_len - cached_len
            return

        # Full hit (seq_len <= cached_len). Only compute the last token to
        # avoid erroneous behavior. FIXME: Ideally we should directly mark all
        # tokens as computed in the scheduler and do not schedule this
        # sequence, so this case should not happen.
        inter_data.input_tokens[seq_idx] = (
            inter_data.input_tokens[seq_idx][-1:])
        inter_data.input_positions[seq_idx] = (
            inter_data.input_positions[seq_idx][-1:])
        inter_data.query_lens[seq_idx] = 1
        inter_data.context_lens[seq_idx] = seq_len - 1

    def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                    seq_idx: int,
                                    seq_group_metadata: SequenceGroupMetadata):
        """Apply the sliding window to a decoding sequence.

        For decode steps with a sliding window configured, caps
        ``seq_lens[seq_idx]`` and records the window's block count in
        ``curr_sliding_window_blocks[seq_idx]``. Otherwise the length is left
        as-is and the block count is set to 0.
        """
        full_len = inter_data.seq_lens[seq_idx]
        window_blocks = 0
        capped_len = full_len

        if self.sliding_window is not None and not inter_data.is_prompt:
            # TODO(sang): This is a hack to make sliding window work with
            # paged attn. We can remove it if we make paged attn kernel
            # to properly handle sliding window attn.
            window_blocks = self.sliding_window_blocks
            if self.scheduler_config.use_v2_block_manager:
                # Number of tokens in the (partially filled) last block.
                tail = full_len % self.block_size
                capped_len = min(full_len,
                                 self.block_aligned_sliding_window + tail)
                if tail > 0:
                    window_blocks += 1
            else:
                capped_len = min(full_len, self.sliding_window)

        inter_data.curr_sliding_window_blocks[seq_idx] = window_blocks
        inter_data.seq_lens[seq_idx] = capped_len

    def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
                            seq_idx: int,
                            seq_group_metadata: SequenceGroupMetadata):
        """Record LoRA index and prompt mappings for one sequence.

        No-op unless LoRA is enabled. Every query token is tagged with the
        group's LoRA id in the index mapping. The prompt mapping gets one
        entry per query token when prompt logprobs are requested, otherwise
        a single entry.
        """
        if not self.enable_lora:
            return

        lora_id = seq_group_metadata.lora_int_id
        if lora_id > 0:
            inter_data.lora_requests.add(seq_group_metadata.lora_request)

        num_query_tokens = inter_data.query_lens[seq_idx]
        inter_data.lora_index_mapping.append([lora_id] * num_query_tokens)

        params = seq_group_metadata.sampling_params
        if params and params.prompt_logprobs is not None:
            prompt_rows = num_query_tokens
        else:
            prompt_rows = 1
        inter_data.lora_prompt_mapping.append([lora_id] * prompt_rows)

    def _compute_prompt_adapter_input(
            self, inter_data: InterDataForSeqGroup,
            seq_group_metadata: SequenceGroupMetadata):
        """If prompt adapter is enabled, compute index and prompt mapping.

        Only prompt groups with a positive ``prompt_adapter_id`` are
        affected; such groups hold exactly one sequence. The index mapping
        tags the first ``prompt_adapter_num_virtual_tokens`` query tokens
        with the adapter id and the rest with 0. The prompt mapping has one
        entry per query token when prompt logprobs are requested, else a
        single entry.
        """
        if not self.enable_prompt_adapter:
            return

        prompt_adapter_id = seq_group_metadata.prompt_adapter_id
        if prompt_adapter_id <= 0 or not inter_data.is_prompt:
            return

        # We expect only one sequence in the group when is_prompt=True.
        assert inter_data.n_seqs == 1
        query_len = inter_data.query_lens[0]
        inter_data.prompt_adapter_request = (
            seq_group_metadata.prompt_adapter_request)

        num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
        inter_data.prompt_adapter_index_mapping = [
            prompt_adapter_id
        ] * num_tokens + [0] * (query_len - num_tokens)

        # Match _compute_lora_input: any non-None prompt_logprobs (including
        # 0) requests per-token prompt rows. A plain truthiness test would
        # wrongly collapse prompt_logprobs=0 to a single entry.
        sampling_params = seq_group_metadata.sampling_params
        if sampling_params and sampling_params.prompt_logprobs is not None:
            prompt_rows = query_len
        else:
            prompt_rows = 1
        inter_data.prompt_adapter_prompt_mapping = (
            [prompt_adapter_id] * prompt_rows)

    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                   seq_group_metadata: SequenceGroupMetadata):
        """Map the group's multi-modal data, if any, into model inputs."""
        mm_data = seq_group_metadata.multi_modal_data
        if mm_data:
            inter_data.multi_modal_inputs = self.multi_modal_input_mapper(
                mm_data)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Add a sequence group to the builder.

        Re-initializes a pooled ``InterDataForSeqGroup`` for the group,
        appends it to ``inter_data_list``, then runs the per-sequence and
        per-group compute steps on it. Prompt groups must hold exactly one
        sequence and mark the batch as not decode-only.
        """
        seq_ids = seq_group_metadata.seq_data.keys()
        num_seqs = len(seq_ids)
        is_prompt = seq_group_metadata.is_prompt

        if is_prompt:
            assert num_seqs == 1
            self.decode_only = False

        inter_data = self.init_cached_inter_data(
            request_id=seq_group_metadata.request_id,
            seq_ids=seq_ids,
            is_prompt=is_prompt,
            block_tables=seq_group_metadata.block_tables,
            computed_block_nums=seq_group_metadata.computed_block_nums,
            reinit=True,
            reinit_use_defaults=True,
        )
        self.inter_data_list.append(inter_data)

        # Sequence index is the outer loop; steps run in list order for each.
        for seq_idx, step in itertools.product(range(num_seqs),
                                               self.per_seq_compute_fns):
            step(inter_data, seq_idx, seq_group_metadata)
        for group_step in self.per_seq_group_compute_fns:
            group_step(inter_data, seq_group_metadata)

    def _use_captured_graph(self, batch_size: int,
                            max_decode_seq_len: int) -> bool:
        """Return whether this batch can run on a captured CUDA graph.

        Requires a decode-only batch, eager mode not enforced, a batch no
        larger than the largest captured size, and a maximum decode sequence
        length within ``runner.max_seq_len_to_capture``.
        """
        if not self.decode_only:
            return False
        if self.runner.model_config.enforce_eager:
            return False
        return (batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and create on-device tensors.

        Flattens the per-sequence-group data accumulated in
        ``self.inter_data_list`` into batch-level lists and pads them if the
        batch qualifies for a captured CUDA graph. It then copies tokens and
        positions to ``self.runner.device`` and assembles the attention,
        LoRA, prompt-adapter and multi-modal metadata.

        Returns:
            A ``self.model_input_cls`` instance. If there are no input tokens
            at all (e.g. every prefill fully hit the prefix cache and there is
            no decode request), an empty ``self.model_input_cls()`` is
            returned.
        """
        # Combine and flatten intermediate data.
        input_tokens = []
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)

        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        input_positions = []
        for inter_data in self.inter_data_list:
            for cur_input_positions in inter_data.input_positions:
                input_positions.extend(cur_input_positions)

        seq_lens = []
        max_decode_seq_len = 0
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
            if not inter_data.is_prompt:
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(inter_data.seq_lens))

        query_lens = []
        for inter_data in self.inter_data_list:
            query_lens.extend(inter_data.query_lens)

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manage the cache by themselves.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        batch_size = len(input_tokens)
        use_captured_graph = self._use_captured_graph(batch_size,
                                                      max_decode_seq_len)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        # -1 marks "no CUDA graph" and is forwarded unchanged to
        # `attn_metadata_builder.build` below (which presumably relies on
        # this sentinel -- confirm before changing it).
        cuda_graph_pad_size = -1
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            cuda_graph_pad_size = graph_batch_size - batch_size
            batch_size = graph_batch_size
        # Only a strictly positive pad size means padding slots exist. The -1
        # sentinel is truthy, so a plain truthiness test would be misleading.
        needs_padding = cuda_graph_pad_size > 0

        # Tokens and positions.
        if needs_padding:
            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
        assert self.runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)
        input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
                                                  self.runner.device,
                                                  self.runner.pin_memory)

        # Sequence and query lengths.
        if needs_padding:
            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests: Set[LoRARequest] = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                r
                for data in self.inter_data_list for r in data.lora_requests
            }
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            if needs_padding:
                lora_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])

            lora_mapping = LoRAMapping(index_mapping=lora_index_mapping,
                                       prompt_mapping=lora_prompt_mapping,
                                       is_prefill=not self.decode_only)

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            if needs_padding:
                prompt_adapter_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_inputs_list = [
            data.multi_modal_inputs for data in self.inter_data_list
            if data.multi_modal_inputs is not None
        ]
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
804
805


806
807
808
809
810
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.

    Holds the engine configs, loads the model (optionally wrapped by LoRA /
    prompt-adapter managers), prepares batched model inputs via
    ``_builder_cls``, profiles memory with dummy inputs and captures CUDA
    graphs for decode batches.
    """
    _model_input_cls: Type[TModelInputForGPU]
    _builder_cls: Type[ModelInputForGPUBuilder]

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
        observability_config: Optional[ObservabilityConfig] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Store configs and set up attention, CUDA-graph and multi-modal state.

        The model itself is not loaded here; call ``load_model`` for that.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.prompt_adapter_config = prompt_adapter_config
        self.return_hidden_states = return_hidden_states
        self.observability_config = observability_config

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture

        # One dict (batch size -> captured runner) per pipeline-parallel
        # virtual engine.
        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
            parallel_config)

        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        # Models without attention heads get no attention backend (None).
        self.attn_backend = get_attn_backend(
            num_attn_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
                weakref.proxy(self))
        else:
            self.attn_state = CommonAttentionState(weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        # Used to cache python objects
        self.inter_data_cache: Dict[int, PyObjectCache] = {}
        self.sampling_metadata_cache: SamplingMetadataCache = \
            SamplingMetadataCache()

    def load_model(self) -> None:
        """Load the model weights and attach optional adapters.

        Records the memory consumed by loading in ``model_memory_usage``.
        It then wraps the model with the LoRA and/or prompt-adapter managers
        when those configs are set. On ROCm with an FP8 KV cache, it loads
        the KV-cache scaling factors from ``quantization_param_path`` if one
        is given. Finally it optionally wraps the model with
        ``torch.compile``, controlled by ``VLLM_TEST_DYNAMO_GRAPH_CAPTURE``.

        Raises:
            RuntimeError: If FP8 KV-cache scaling factors were provided but
                the model has no ``load_kv_cache_scales`` method.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with CudaMemoryProfiler() as m:
            self.model = get_model(model_config=self.model_config,
                                   device_config=self.device_config,
                                   load_config=self.load_config,
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config,
                                   cache_config=self.cache_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert supports_lora(self.model), "Model does not support LoRA"
            assert not supports_multimodal(
                self.model
            ), "To be tested: Multi-modal model with LoRA settings."

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.prompt_adapter_config:
            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.device,
                self.prompt_adapter_config)
            self.model = (
                self.prompt_adapter_manager.create_prompt_adapter_manager(
                    self.model))

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply logger-style "%s" formatting,
                    # so the class name must be interpolated explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
            self.model = torch.compile(self.model,
                                       fullgraph=True,
                                       backend="eager")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the loaded model to ``path`` via ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Save the loaded model via ``TensorizerLoader.save_model``."""
        from vllm.model_executor.model_loader.loader import TensorizerLoader
        TensorizerLoader.save_model(
            self.model,
            tensorizer_config=tensorizer_config,
        )

    def get_max_block_per_batch(self) -> int:
        """Return ceil(max_seq_len_to_capture / block_size)."""
        block_size = self.block_size
        return (self.max_seq_len_to_capture + block_size - 1) // block_size

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForGPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        builder.reset_cached_inter_data()

        return builder.build()  # type: ignore

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prompts to profile memory usage.

        The dummy batch uses up to ``max_num_seqs`` prompt sequences whose
        lengths sum to ``max_num_batched_tokens``. When LoRA is enabled, one
        dummy LoRA per ``max_loras`` slot is attached. When the model
        consumes multi-modal tokens, the number of sequences is capped so
        the multi-modal work is maximized.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                # Assign the dummy LoRAs round-robin across all sequences.
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, dummy_multi_modal_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()

    def remove_all_loras(self):
        """Remove every LoRA adapter. Raises RuntimeError if LoRA is off."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_adapters()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` with ``lora_mapping``.

        Raises RuntimeError if LoRA is not enabled.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter. Raises RuntimeError if LoRA is not enabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter. Raises RuntimeError if LoRA is off."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter. Raises RuntimeError if LoRA is not enabled."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> Set[int]:
        """List LoRA adapter ids. Raises RuntimeError if LoRA is off."""
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()

    def remove_all_prompt_adapters(self):
        """Remove every prompt adapter.

        Raises RuntimeError if prompt adapters are not enabled.
        """
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        self.prompt_adapter_manager.remove_all_adapters()

    def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest],
            prompt_adapter_mapping: PromptAdapterMapping) -> None:
        """Activate prompt adapters with the given mapping.

        Raises RuntimeError if prompt adapters are not enabled.
        """
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        self.prompt_adapter_manager.set_active_adapters(
            prompt_adapter_requests, prompt_adapter_mapping)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Add a prompt adapter.

        Raises RuntimeError if prompt adapters are not enabled.
        """
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove a prompt adapter.

        Raises RuntimeError if prompt adapters are not enabled.
        """
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a prompt adapter.

        Raises RuntimeError if prompt adapters are not enabled.
        """
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """List prompt adapter ids.

        Raises RuntimeError if prompt adapters are not enabled.
        """
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.list_adapters()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()

        # Prepare dummy previous_hidden_states only if needed by the model.
        # This is used by draft models such as EAGLE.
        previous_hidden_states = None
        if "previous_hidden_states" in inspect.signature(
                self.model.forward).parameters:
            previous_hidden_states = torch.empty(
                [max_batch_size,
                 self.model_config.get_hidden_size()],
                dtype=self.model_config.dtype,
                device=self.device)

        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Prepare buffer for outputs. These will be reused for all batch sizes.
        # It will be filled after the first graph capture.
        # NOTE(review): nothing in this method assigns into this list, so each
        # entry stays None when passed to capture below -- confirm intent.
        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
            None
        ] * self.parallel_config.pipeline_parallel_size

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        with self.attn_state.graph_capture(
                max_batch_size), graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    attn_metadata = (
                        self.attn_state.graph_capture_get_metadata_for_batch(
                            batch_size))

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            index_mapping=[0] * batch_size,
                            prompt_mapping=[0] * batch_size,
                            is_prefill=False)
                        self.set_active_loras(set(), lora_mapping)

                    if self.prompt_adapter_config:
                        prompt_adapter_mapping = PromptAdapterMapping(
                            [-1] * batch_size,
                            [-1] * batch_size,
                        )
                        self.set_active_prompt_adapters(
                            set(), prompt_adapter_mapping)

                    # NOTE(review): __init__ leaves attn_backend as None for
                    # models without attention heads; this call would then
                    # raise AttributeError -- confirm such models never reach
                    # graph capture.
                    graph_runner = CUDAGraphRunner(
                        self.model, self.attn_backend.get_name(),
                        self.attn_state.graph_clone(batch_size))

                    capture_inputs = {
                        "input_ids":
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[:batch_size],
                        "hidden_or_intermediate_states":
                        hidden_or_intermediate_states[
                            virtual_engine]  # type: ignore
                        [:batch_size]
                        if hidden_or_intermediate_states[virtual_engine]
                        is not None else None,
                        "intermediate_inputs":
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
                        "kv_caches":
                        kv_caches[virtual_engine],
                        "attn_metadata":
                        attn_metadata,
                        "memory_pool":
                        self.graph_memory_pool,
                        "stream":
                        graph_capture_context.stream
                    }
                    if previous_hidden_states is not None:
                        capture_inputs["previous_hidden_states"] = (
                            previous_hidden_states[:batch_size])

                    if self.has_seqlen_agnostic:
                        # Only used by Mamba-based models CUDA graph atm (Jamba)
                        capture_inputs.update({
                            "seqlen_agnostic_capture_inputs":
                            self.model.get_seqlen_agnostic_capture_inputs(
                                batch_size)
                        })
                    graph_runner.capture(**capture_inputs)
                    # Share one memory pool across all captured graphs.
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model config."""
        return self.model_config.get_vocab_size()

1334

1335
1336
1337
1338
1339
1340
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a tensor dict received via broadcast.

        Delegates to
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
        using this runner's attention backend.

        Args:
            tensor_dict: The broadcast tensor dictionary.

        Returns:
            The reconstructed ``ModelInputForGPUWithSamplingMetadata``.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Sampling metadata is only built on the last pipeline-parallel rank;
        other ranks get ``sampling_metadata=None``. ``is_prompt`` is taken
        from the first sequence group, or is ``None`` for an empty batch.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        if get_pp_group().is_last_rank:
            # Sampling metadata is only required for the final pp group
            generators = self.get_generators(finished_requests_ids)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, model_input.seq_lens,
                model_input.query_lens, self.device, self.pin_memory,
                generators, self.sampling_metadata_cache)
        else:
            sampling_metadata = None
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @staticmethod
    def _accumulated_forward_time(
        forward_start: torch.cuda.Event,
        forward_end: torch.cuda.Event,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> float:
        """Return this stage's forward time plus time from earlier stages.

        Waits for ``forward_end``, measures the interval between the two
        events (``torch.cuda.Event.elapsed_time``), and adds any
        ``"model_forward_time"`` value carried in ``intermediate_tensors``
        from a previous pipeline stage (0.0 when absent).
        """
        forward_end.synchronize()
        model_forward_time = forward_start.elapsed_time(forward_end)
        orig_model_forward_time = 0.0
        if intermediate_tensors is not None:
            orig_model_forward_time = intermediate_tensors.tensors.get(
                "model_forward_time", torch.tensor(0.0)).item()
        return orig_model_forward_time + model_forward_time

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward pass; on the last PP rank also sample tokens.

        Args:
            model_input: Prepared input (see ``prepare_model_input``).
            kv_caches: KV cache tensors passed through to the model.
            intermediate_tensors: Outputs from the previous pipeline stage,
                if any.
            num_steps: Must be 1; multi-step is not supported here.

        Returns:
            On non-last pipeline ranks, whatever the model returned
            (intermediate tensors). On the last rank, ``[]`` for non-driver
            workers, otherwise a single-element list with the
            ``SamplerOutput``.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        # Guard decode_meta: a batch with neither prefill nor decode metadata
        # would otherwise raise AttributeError here.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}

        # Evaluated once; the same condition gates event creation, recording
        # and reading below.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                         device=self.device),
            **seqlen_agnostic_kwargs)

        if collect_forward_time:
            model_forward_end.record()

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (collect_forward_time and self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)):
                # Forward the accumulated time to the next pipeline stage.
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(
                        self._accumulated_forward_time(
                            model_forward_start, model_forward_end,
                            intermediate_tensors)))
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if collect_forward_time and output is not None:
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = self._accumulated_forward_time(
                model_forward_start, model_forward_end, intermediate_tensors)

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
                output.prefill_hidden_states = hidden_or_intermediate_states
            elif decode_meta is not None and decode_meta.use_cuda_graph:
                # CUDA-graph inputs may be padded (see prepare_model_input);
                # keep only the rows that correspond to selected tokens.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]


class CUDAGraphRunner:
    """Captures one model forward pass as a CUDA graph and replays it.

    ``capture`` records the graph and keeps references to the tensors used
    during capture as static input/output buffers. ``forward`` copies fresh
    inputs into those input buffers, replays the graph, and returns the
    output buffer(s).
    """

    def __init__(self, model: nn.Module, backend_name: str,
                 attn_state: AttentionState):
        self.model = model
        self.backend_name = backend_name
        self.attn_state = attn_state

        # Populated by capture(); empty until then.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; asserts that ``capture`` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, then record one forward pass into a graph.

        Args:
            input_ids, positions, kv_caches, attn_metadata: Model inputs used
                for capture; the input tensors become static input buffers.
            hidden_or_intermediate_states: Optional preallocated output
                buffer(s). If given, the model outputs are copied into them
                inside the graph; otherwise the model's own outputs are kept
                as the output buffers.
            intermediate_inputs: Inputs from a previous pipeline stage; their
                tensors are also registered as input buffers.
            memory_pool: Passed as ``pool`` to ``torch.cuda.graph``.
            stream: Stream used for capture.
            **kwargs: Extra model keyword arguments, also kept as input
                buffers.

        Returns:
            The output buffer(s) written by the captured graph.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            **self.attn_state.get_graph_input_buffers(attn_metadata),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy new inputs into the static buffers and replay the graph.

        Returns:
            On the last pipeline rank, the ``"hidden_states"`` output buffer;
            otherwise the captured intermediate output buffers. These are the
            same tensors on every replay.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
                                                    attn_metadata)
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if "previous_hidden_states" in self.input_buffers:
            self.input_buffers["previous_hidden_states"].copy_(
                kwargs["previous_hidden_states"], non_blocking=True)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                # Timing entries are scalar bookkeeping values, not graph
                # inputs (presumably absent from the captured buffers).
                if key not in ("model_execute_time", "model_forward_time"):
                    self.input_buffers[key].copy_(intermediate_tensors[key],
                                                  non_blocking=True)
        # Run the graph.
        self.graph.replay()
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def _get_graph_batch_size(batch_size: int) -> int:
    """Map an actual batch size to the padded CUDA-graph batch size.

    Sizes up to 2 are returned unchanged, 3 and 4 map to 4, and anything
    larger is rounded up to the next multiple of _BATCH_SIZE_ALIGNMENT
    (so the sequence is 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT, ...).
    """
    if batch_size > 4:
        # Ceiling division via negated floor division.
        num_chunks = -(-batch_size // _BATCH_SIZE_ALIGNMENT)
        return num_chunks * _BATCH_SIZE_ALIGNMENT
    return batch_size if batch_size <= 2 else 4