# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from unittest.mock import patch import numpy as np import torch import torch.nn as nn from tqdm import tqdm from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.forward_context import set_forward_context from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.input_batch import InputBuffers class CudaGraphManager: def __init__( self, vllm_config: VllmConfig, device: torch.device, ): self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config self.device = device self.max_model_len = vllm_config.model_config.max_model_len self.max_num_reqs = self.scheduler_config.max_num_seqs self.dp_size = vllm_config.parallel_config.data_parallel_size self.compilation_config = vllm_config.compilation_config assert self.compilation_config is not None if self.compilation_config.cudagraph_mode is None: self.cudagraph_mode = CUDAGraphMode.NONE else: self.cudagraph_mode = self.compilation_config.cudagraph_mode if self.compilation_config.cudagraph_capture_sizes is not None: cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) # Limit the cudagraph sizes to the max decode batch size. self.cudagraph_sizes = [ x for x in cudagraph_sizes if x <= self.max_num_reqs ] else: self.cudagraph_sizes = [] self.padded_sizes = self._init_padded_sizes() self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.pool = torch.cuda.graph_pool_handle() self.hidden_states: torch.Tensor | None = None def _init_padded_sizes(self) -> dict[int, int]: if not self.cudagraph_mode.has_full_cudagraphs(): # Full cuda graphs are not used. return {} if not self.cudagraph_sizes: return {} padded_sizes: dict[int, int] = {} for i in range(1, self.cudagraph_sizes[-1] + 1): for x in self.cudagraph_sizes: if i <= x: padded_sizes[i] = x break return padded_sizes def needs_capture(self) -> bool: return len(self.padded_sizes) > 0 def get_cudagraph_size( self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int, ) -> int | None: if not self.cudagraph_mode.has_full_cudagraphs(): return None if self.cudagraph_mode != CUDAGraphMode.FULL: # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding). all_decode = all( x == 1 for x in scheduler_output.num_scheduled_tokens.values() ) if not all_decode: # Prefill is included. return None return self.padded_sizes.get(num_tokens_after_padding) def capture_graph( self, batch_size: int, model: nn.Module, input_buffers: InputBuffers, block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, ) -> None: assert batch_size not in self.graphs # Prepare dummy inputs. input_ids = input_buffers.input_ids.gpu[:batch_size] positions = input_buffers.positions[:batch_size] input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1) input_buffers.query_start_loc.np[batch_size:] = batch_size input_buffers.query_start_loc.copy_to_gpu() input_buffers.seq_lens[:batch_size] = self.max_model_len input_buffers.seq_lens[batch_size:] = 0 input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables] slot_mappings = block_tables.slot_mappings[:, :batch_size] attn_metadata = build_attn_metadata( attn_metadata_builders=attn_metadata_builders, num_reqs=batch_size, num_tokens=batch_size, query_start_loc=input_buffers.query_start_loc, seq_lens=input_buffers.seq_lens, seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32), num_computed_tokens_cpu=None, # FIXME block_tables=input_block_tables, slot_mappings=slot_mappings, kv_cache_config=kv_cache_config, ) if self.dp_size > 1: num_tokens_across_dp = torch.full( (self.dp_size,), batch_size, dtype=torch.int32, device="cpu", ) else: num_tokens_across_dp = None # Warm up. with set_forward_context( attn_metadata, self.vllm_config, num_tokens=batch_size, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, ): hidden_states = model( input_ids=input_ids, positions=positions, ) if self.hidden_states is None: self.hidden_states = torch.empty_like(hidden_states) # Capture the graph. graph = torch.cuda.CUDAGraph() with ( patch("torch.cuda.empty_cache", lambda: None), set_forward_context( attn_metadata, self.vllm_config, num_tokens=batch_size, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, ), torch.cuda.graph(graph, self.pool), ): hidden_states = model( input_ids=input_ids, positions=positions, ) self.hidden_states[:batch_size] = hidden_states self.graphs[batch_size] = graph @torch.inference_mode() def capture( self, model: nn.Module, input_buffers: InputBuffers, block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, ) -> None: assert self.needs_capture() # Capture larger graphs first. sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True) if is_global_first_rank(): sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs") with graph_capture(device=self.device): for batch_size in sizes_to_capture: self.capture_graph( batch_size, model, input_buffers, block_tables, attn_metadata_builders, kv_cache_config, ) def run(self, batch_size: int) -> torch.Tensor: assert batch_size in self.graphs self.graphs[batch_size].replay() assert self.hidden_states is not None return self.hidden_states[:batch_size]