Unverified Commit 11d3976b authored by zhrrr's avatar zhrrr Committed by GitHub
Browse files

[Model Runner V2] support piecewise & mixed cudagraph (#32771)


Signed-off-by: default avatarzhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
parent 40da9625
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable from collections.abc import Callable
from typing import Any from typing import Any
import numpy as np import numpy as np
...@@ -11,7 +11,8 @@ from tqdm import tqdm ...@@ -11,7 +11,8 @@ from tqdm import tqdm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
...@@ -34,14 +35,27 @@ class CudaGraphManager: ...@@ -34,14 +35,27 @@ class CudaGraphManager:
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.uniform_decode_query_len = 1
spec_config = vllm_config.speculative_config
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
use_uniform_decode_cudagraph = (
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.cudagraph_mode.separate_routine()
)
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes, self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs, self.max_num_reqs,
self.max_num_tokens, self.max_num_tokens,
self.cudagraph_mode, self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
) )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
...@@ -54,20 +68,16 @@ class CudaGraphManager: ...@@ -54,20 +68,16 @@ class CudaGraphManager:
return len(self.cudagraph_sizes) > 0 return len(self.cudagraph_sizes) > 0
def get_cudagraph_size( def get_cudagraph_size(
self, self, num_tokens: int, uniform_decode: bool = False
num_tokens_after_padding: int,
num_tokens_per_request: Iterable[int],
) -> int | None: ) -> int | None:
return get_cudagraph_size( if uniform_decode and self.uniform_decode_cudagraph_sizes:
num_tokens_after_padding, return self.uniform_decode_cudagraph_sizes.get(num_tokens)
num_tokens_per_request, return self.cudagraph_sizes.get(num_tokens)
self.cudagraph_sizes,
self.cudagraph_mode,
)
def capture_graph( def capture_graph(
self, self,
num_tokens: int, num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module, model: nn.Module,
input_buffers: InputBuffers, input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None, mrope_positions: torch.Tensor | None,
...@@ -75,8 +85,25 @@ class CudaGraphManager: ...@@ -75,8 +85,25 @@ class CudaGraphManager:
block_tables: BlockTables, block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None: ) -> None:
num_reqs = min(num_tokens, self.max_num_reqs) # select and check capture function
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# prepare inputs
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens] input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] positions = input_buffers.positions[:num_tokens]
if self.uses_mrope: if self.uses_mrope:
...@@ -92,6 +119,9 @@ class CudaGraphManager: ...@@ -92,6 +119,9 @@ class CudaGraphManager:
attn_metadata_builders, attn_metadata_builders,
self.max_model_len, self.max_model_len,
kv_cache_config, kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
...@@ -112,13 +142,40 @@ class CudaGraphManager: ...@@ -112,13 +142,40 @@ class CudaGraphManager:
if self.hidden_states is None: if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states) self.hidden_states = torch.empty_like(hidden_states)
capture_fn(
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph. # Capture the graph.
assert num_tokens not in self.graphs assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with ( with (
set_forward_context( set_forward_context(
attn_metadata, attn_metadata=attn_metadata,
self.vllm_config, vllm_config=self.vllm_config,
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
...@@ -131,9 +188,44 @@ class CudaGraphManager: ...@@ -131,9 +188,44 @@ class CudaGraphManager:
positions=positions, positions=positions,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
with set_forward_context(
attn_metadata=None, # piecewise no need attn_metadata
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
@torch.inference_mode() @torch.inference_mode()
def capture( def capture(
self, self,
...@@ -144,11 +236,11 @@ class CudaGraphManager: ...@@ -144,11 +236,11 @@ class CudaGraphManager:
block_tables: BlockTables, block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
has_lora: bool = False,
) -> None: ) -> None:
capture_graphs( common_kwargs = dict(
self.cudagraph_sizes, device=self.device,
self.device, capture_fn=self.capture_graph,
self.capture_graph,
model=model, model=model,
input_buffers=input_buffers, input_buffers=input_buffers,
mrope_positions=mrope_positions, mrope_positions=mrope_positions,
...@@ -156,10 +248,50 @@ class CudaGraphManager: ...@@ -156,10 +248,50 @@ class CudaGraphManager:
block_tables=block_tables, block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders, attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
has_lora=has_lora,
) )
def run(self, num_tokens: int) -> torch.Tensor: # Phase 1: Capture for mixed prefill-decode batches if needed.
assert num_tokens in self.graphs mixed_mode = self.cudagraph_mode.mixed_mode()
if mixed_mode != CUDAGraphMode.NONE:
capture_graphs(
cudagraph_sizes=self.cudagraph_sizes,
capture_cudagraph_mode=mixed_mode,
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
uniform_decode=False,
**common_kwargs,
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if self.uniform_decode_cudagraph_sizes:
capture_graphs(
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
capture_cudagraph_mode=CUDAGraphMode.FULL,
desc="Capturing CUDA graphs (decode, FULL)",
uniform_decode=True,
**common_kwargs,
)
def get_cudagraph_runtime_mode(
self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_tokens == max_query_len * num_reqs
)
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
elif is_uniform_decode:
cudagraph_mode = self.cudagraph_mode.decode_mode()
else:
cudagraph_mode = self.cudagraph_mode.mixed_mode()
return cudagraph_mode, cudagraph_size
def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
self.graphs[num_tokens].replay() self.graphs[num_tokens].replay()
assert self.hidden_states is not None assert self.hidden_states is not None
return self.hidden_states[:num_tokens] return self.hidden_states[:num_tokens]
...@@ -170,22 +302,18 @@ def get_cudagraph_sizes( ...@@ -170,22 +302,18 @@ def get_cudagraph_sizes(
max_num_reqs: int, max_num_reqs: int,
max_num_tokens: int, max_num_tokens: int,
cudagraph_mode: CUDAGraphMode, cudagraph_mode: CUDAGraphMode,
) -> dict[int, int]: uniform_decode_query_len: int = 1,
if not cudagraph_mode.has_full_cudagraphs(): uniform_decode_cudagraph: bool = False,
return {} ) -> tuple[dict[int, int], dict[int, int]]:
# Support both FULL and PIECEWISE cudagraph modes
if cudagraph_mode == CUDAGraphMode.NONE:
return {}, {}
if not capture_sizes: if not capture_sizes:
return {} return {}, {}
capture_sizes = sorted(capture_sizes) capture_sizes = sorted(capture_sizes)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound = (
max_num_reqs
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
else max_num_tokens
)
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
if not capture_sizes: if not capture_sizes:
return {} return {}, {}
cudagraph_sizes: dict[int, int] = {} cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1): for i in range(1, capture_sizes[-1] + 1):
...@@ -193,45 +321,34 @@ def get_cudagraph_sizes( ...@@ -193,45 +321,34 @@ def get_cudagraph_sizes(
if i <= x: if i <= x:
cudagraph_sizes[i] = x cudagraph_sizes[i] = x
break break
return cudagraph_sizes
def get_cudagraph_size(
num_tokens_after_dp_padding: int,
num_tokens_per_request: Iterable[int],
cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode,
) -> int | None:
if not cudagraph_mode.has_full_cudagraphs():
# No full CUDA graph is used.
return None
size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None:
# No CUDA graph for this size.
return None
is_mixed = any(x > 1 for x in num_tokens_per_request) uniform_decode_cudagraph_sizes: dict[int, int] = {}
if is_mixed and cudagraph_mode.mixed_mode() != CUDAGraphMode.FULL: if uniform_decode_cudagraph:
# Prefill is included, and this mode doesn't use CUDA graph for it. max_num_tokens = max_num_reqs * uniform_decode_query_len
return None uniform_decode_cudagraph_sizes = {
return size k: v
for k, v in cudagraph_sizes.items()
if v <= max_num_tokens and v >= uniform_decode_query_len
}
return cudagraph_sizes, uniform_decode_cudagraph_sizes
def capture_graphs( def capture_graphs(
cudagraph_sizes: dict[int, int], cudagraph_sizes: dict[int, int],
device: torch.device, device: torch.device,
capture_fn: Callable, capture_fn: Callable,
capture_cudagraph_mode: CUDAGraphMode,
desc: str = "Capturing CUDA graphs",
**capture_kwargs, **capture_kwargs,
) -> None: ) -> None:
# Capture larger graphs first. # Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True) sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank(): if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs") sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
with graph_capture(device=device): with graph_capture(device=device):
for size in sizes_to_capture: for size in sizes_to_capture:
capture_fn(size, **capture_kwargs) capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
def prepare_inputs_to_capture( def prepare_inputs_to_capture(
...@@ -242,8 +359,12 @@ def prepare_inputs_to_capture( ...@@ -242,8 +359,12 @@ def prepare_inputs_to_capture(
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int, max_model_len: int,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
num_tokens_per_req = num_tokens // num_reqs if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens query_start_loc_np[-1] = num_tokens
......
...@@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N ...@@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
def get_batch_metadata_across_dp( def get_batch_metadata_across_dp(
num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]: cudagraph_size: int,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dp_size > 1 assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization. # Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group group = get_dp_group().cpu_group
tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu") tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size tensor[1][dp_rank] = cudagraph_size
tensor[2][dp_rank] = cudagraph_runtime_mode
dist.all_reduce(tensor, group=group) dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1] return tensor[0], tensor[1], tensor[2]
def get_cudagraph_and_dp_padding( def get_cudagraph_and_dp_padding(
num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int num_tokens: int,
) -> tuple[bool, int, torch.Tensor | None]: cudagraph_size: int | None,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1: if dp_size == 1:
if cudagraph_size is not None: if cudagraph_size is not None:
return True, cudagraph_size, None return cudagraph_size, None, cudagraph_runtime_mode
else: else:
return False, num_tokens, None return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available)
if num_tokens == 0: if num_tokens == 0:
cudagraph_size = 0 cudagraph_size = 0
elif cudagraph_size is None: elif cudagraph_size is None:
cudagraph_size = -1 cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
num_tokens, cudagraph_size, dp_size, dp_rank num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
get_batch_metadata_across_dp(
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
)
) )
if torch.all(num_tokens_across_dp == 0).item(): if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run. # All ranks have zero tokens to run.
return False, 0, None return 0, None, 0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
if torch.all(cudagraph_size_across_dp != -1).item(): if synced_cudagraph_mode != 0 and all_have_cudagraph:
# All ranks use CUDA graph or have zero tokens. # All ranks use cudagraph. Pad to max cudagraph_size.
# Use CUDA graph for all ranks.
# Pad all ranks to the maximum CUDA graph size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item()) max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size num_tokens_across_dp[:] = max_cudagraph_size
return True, max_cudagraph_size, num_tokens_across_dp return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else: else:
# Some ranks do not use CUDA graph. Use eager mode for all ranks. # Fall back to eager mode (no cudagraph).
# No padding is needed except for ranks that have no tokens to run. # Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1) num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item()) num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return False, num_tokens_after_padding, num_tokens_across_dp return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
...@@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import ( ...@@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import (
get_pp_group, get_pp_group,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
from vllm.forward_context import set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.do_spec_decode = False self.do_spec_decode = False
self.num_speculative_steps = 0 self.num_speculative_steps = 0
self.speculator = None self.speculator = None
self.req_states = RequestState( self.req_states = RequestState(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
...@@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables=self.block_tables, block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders, attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
has_lora=self.lora_config is not None,
) )
if self.do_spec_decode: if self.do_spec_decode:
self.speculator.capture_model() self.speculator.capture_model()
...@@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output = self.kv_connector.no_forward(scheduler_output) empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output return empty_output
# Get the CUDA graph size. None means no CUDA graph is used. # Get local cudagraph mode and size.
cudagraph_size = self.cudagraph_manager.get_cudagraph_size( local_cudagraph_mode, local_cudagraph_size = (
scheduler_output.total_num_scheduled_tokens, self.cudagraph_manager.get_cudagraph_runtime_mode(
scheduler_output.num_scheduled_tokens.values(), num_reqs=len(scheduler_output.num_scheduled_tokens),
num_tokens=scheduler_output.total_num_scheduled_tokens,
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
) )
use_cudagraph, num_tokens_after_padding, num_tokens_across_dp = (
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding( get_cudagraph_and_dp_padding(
scheduler_output.total_num_scheduled_tokens, scheduler_output.total_num_scheduled_tokens,
cudagraph_size, local_cudagraph_size,
local_cudagraph_mode.value,
self.parallel_config.data_parallel_size, self.parallel_config.data_parallel_size,
self.parallel_config.data_parallel_rank, self.parallel_config.data_parallel_rank,
) )
) )
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if num_tokens_after_padding == 0: if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run. # All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output) empty_output = self.kv_connector.no_forward(scheduler_output)
...@@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# FIXME(woosuk): Fix warmup for LoRA. # FIXME(woosuk): Fix warmup for LoRA.
# Run model. # Run model.
if use_cudagraph: if cudagraph_runtime_mode == CUDAGraphMode.FULL:
# Run CUDA graph. # Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors, # NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers. # because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.cudagraph_manager.run( hidden_states = self.cudagraph_manager.run_fullgraph(
input_batch.num_tokens_after_padding input_batch.num_tokens_after_padding
) )
else: else:
# Run PyTorch model in eager mode. # For piecewise and eager mode, just call model().
positions = input_batch.positions positions = input_batch.positions
if self.uses_mrope: if self.uses_mrope:
assert input_batch.mrope_positions is not None assert input_batch.mrope_positions is not None
...@@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds = None inputs_embeds = None
assert intermediate_tensors is not None assert intermediate_tensors is not None
batch_descriptor = BatchDescriptor(
num_tokens=input_batch.num_tokens_after_padding,
has_lora=self.lora_config is not None,
)
with set_forward_context( with set_forward_context(
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding, num_tokens=input_batch.num_tokens_after_padding,
# TODO(woosuk): Support piecewise CUDA graph. cudagraph_runtime_mode=cudagraph_runtime_mode,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=input_batch.slot_mappings, slot_mapping=input_batch.slot_mappings,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
......
...@@ -7,7 +7,7 @@ import torch.nn as nn ...@@ -7,7 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
...@@ -103,14 +103,17 @@ class EagleSpeculator: ...@@ -103,14 +103,17 @@ class EagleSpeculator:
attn_metadata: dict[str, Any] | None, attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None, slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
with set_forward_context( with set_forward_context(
attn_metadata, attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
): ):
ret_hidden_states = self.model( ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens], input_ids=self.input_buffers.input_ids[:num_tokens],
...@@ -127,9 +130,11 @@ class EagleSpeculator: ...@@ -127,9 +130,11 @@ class EagleSpeculator:
def generate_draft( def generate_draft(
self, self,
num_reqs: int, num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor], slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None: ) -> None:
pos = self.input_buffers.positions[:num_reqs] pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
...@@ -137,8 +142,14 @@ class EagleSpeculator: ...@@ -137,8 +142,14 @@ class EagleSpeculator:
for step in range(1, self.num_speculative_steps): for step in range(1, self.num_speculative_steps):
# Run the eagle model. # Run the eagle model.
last_hidden_states, hidden_states = self.run_model( last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp num_tokens_padded,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cudagraph_runtime_mode,
) )
last_hidden_states = last_hidden_states[:num_reqs]
hidden_states = hidden_states[:num_reqs]
logits = self.model.compute_logits(last_hidden_states) logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
...@@ -283,12 +294,14 @@ class EagleSpeculator: ...@@ -283,12 +294,14 @@ class EagleSpeculator:
) )
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None: cudagraph_mode = self.cudagraph_manager.cudagraph_mode
# Run CUDA graph. if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
self.cudagraph_manager.run(cudagraph_size) # Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
# Run eager mode. # Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange( query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu" num_reqs + 1, dtype=torch.int32, device="cpu"
) )
...@@ -312,8 +325,13 @@ class EagleSpeculator: ...@@ -312,8 +325,13 @@ class EagleSpeculator:
slot_mappings, self.kv_cache_config slot_mappings, self.kv_cache_config
) )
self.generate_draft( self.generate_draft(
num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None num_reqs,
) # FIXME num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import Any
import torch import torch
...@@ -31,16 +32,17 @@ class EagleCudaGraphManager: ...@@ -31,16 +32,17 @@ class EagleCudaGraphManager:
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
if self.cudagraph_mode == CUDAGraphMode.FULL: self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
self.cudagraph_sizes = get_cudagraph_sizes( # only need to capture uniform decode cudagraph sizes (the 2nd return value)
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes, self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs, self.max_num_reqs,
self.max_num_tokens, self.max_num_tokens,
self.cudagraph_mode, self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
) )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
...@@ -54,12 +56,21 @@ class EagleCudaGraphManager: ...@@ -54,12 +56,21 @@ class EagleCudaGraphManager:
def capture_graph( def capture_graph(
self, self,
num_tokens: int, num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable, generate_fn: Callable,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> None: ) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
...@@ -69,19 +80,70 @@ class EagleCudaGraphManager: ...@@ -69,19 +80,70 @@ class EagleCudaGraphManager:
attn_metadata_builders, attn_metadata_builders,
self.max_model_len, self.max_model_len,
kv_cache_config, kv_cache_config,
uniform_decode_query_len=1,
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up. # Warm up.
generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp) generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph. # Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool): with torch.cuda.graph(graph, self.pool):
generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp) generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
self.graphs[num_tokens] = graph self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode() @torch.inference_mode()
def capture( def capture(
self, self,
...@@ -91,10 +153,15 @@ class EagleCudaGraphManager: ...@@ -91,10 +153,15 @@ class EagleCudaGraphManager:
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> None: ) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs( capture_graphs(
self.cudagraph_sizes, self.cudagraph_sizes,
self.device, self.device,
self.capture_graph, self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn, generate_fn=generate_fn,
input_buffers=input_buffers, input_buffers=input_buffers,
block_tables=block_tables, block_tables=block_tables,
...@@ -102,6 +169,6 @@ class EagleCudaGraphManager: ...@@ -102,6 +169,6 @@ class EagleCudaGraphManager:
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
) )
def run(self, num_tokens: int) -> None: def run_fullgraph(self, num_tokens: int) -> None:
assert num_tokens in self.graphs assert num_tokens in self.graphs
self.graphs[num_tokens].replay() self.graphs[num_tokens].replay()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment