Unverified Commit 22613408 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Misc] Remove pad_for_cudagraphs from config (#30143)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 86c69dc5
......@@ -2,14 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CUDAGraphMode,
ParallelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
......@@ -17,6 +23,7 @@ from vllm.utils.torch_utils import (
_is_torch_equal_or_newer,
is_torch_equal,
)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
......@@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config):
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
def _create_vllm_config_for_validation(
compilation_config: CompilationConfig,
) -> MagicMock:
"""Helper to create a mock VllmConfig for padding validation testing."""
mock_config = MagicMock(spec=VllmConfig)
mock_config.compilation_config = compilation_config
mock_config.scheduler_config = SchedulerConfig.default_factory(max_num_seqs=8)
mock_config.parallel_config = ParallelConfig()
mock_config.speculative_config = None
mock_config.lora_config = None
return mock_config
def test_compile_sizes_padding_validation():
"""Test that compile_sizes with values that would be padded raises an error."""
# cudagraph_capture_sizes=[1, 2, 4, 8] means:
......@@ -488,29 +508,39 @@ def test_compile_sizes_padding_validation():
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=[3],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)
with pytest.raises(ValueError, match="would be padded to"):
config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=[5],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)
config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=[1, 2, 4, 8],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [1, 2, 4, 8]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) # Should not raise
config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=["cudagraph_capture_sizes"],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [1, 2, 4, 8]
......@@ -535,3 +565,5 @@ def test_compile_sizes_padding_validation():
)
config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [3, 5, 7]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
......@@ -9,6 +9,7 @@ from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from ...utils import check_logprobs_close, check_outputs_equal
......@@ -172,7 +173,14 @@ def test_mamba_cache_cg_padding(
tensor dimensions aren't compatible.
"""
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)):
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys(
vllm_config.compilation_config.cudagraph_mode
)
while (
len(example_prompts)
== cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens
):
example_prompts.append(example_prompts[0])
try:
......
......@@ -61,9 +61,6 @@ def _create_vllm_config(
)
compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)
return mock_config
......@@ -169,6 +166,7 @@ class TestCudagraphDispatcher:
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
......@@ -360,7 +358,7 @@ class TestCudagraphIntegration:
):
full_wrapper(input_1)
rt_mode, key = self.dispatcher.dispatch(desc_1)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
# 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "capture_global"
......@@ -369,7 +367,7 @@ class TestCudagraphIntegration:
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "replay"
rt_mode, key = self.dispatcher.dispatch(desc_2)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
# 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
assert action == "capture_global"
......@@ -381,7 +379,7 @@ class TestCudagraphIntegration:
assert action == "replay"
# 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
assert action == "bypass"
......
......@@ -15,7 +15,7 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.config.utils import Range
from vllm.logger import init_logger
logger = init_logger(__name__)
......
......@@ -11,7 +11,7 @@ import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
......
......@@ -581,15 +581,6 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_cudagraph_capture_size],
we can optimize it to list[int] for better lookup performance."""
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled"""
......@@ -639,7 +630,6 @@ class CompilationConfig:
"debug_dump_path",
"cache_dir",
"local_cache_dir",
"bs_to_padded_graph_size",
"traced_files",
"compilation_time",
"static_forward_context",
......@@ -661,7 +651,6 @@ class CompilationConfig:
"enabled_custom_ops": True,
"disabled_custom_ops": True,
"compilation_time": True,
"bs_to_padded_graph_size": True,
"traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
......@@ -882,7 +871,6 @@ class CompilationConfig:
"""To complete the initialization after cudagraph related
configs are set. This includes:
- initialize compile_sizes
- pre-compute the mapping bs_to_padded_graph_size
"""
computed_compile_sizes = []
......@@ -906,23 +894,6 @@ class CompilationConfig:
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()
# Validate that compile_sizes won't be changed by padding.
# Only validate when cudagraphs are actually being used.
if self.compile_sizes and self.cudagraph_mode != CUDAGraphMode.NONE:
for size in self.compile_sizes:
if size <= self.max_cudagraph_capture_size:
padded = self.bs_to_padded_graph_size[size]
if padded != size:
raise ValueError(
f"compile_sizes contains {size} which would be "
f"padded to {padded}. All compile_sizes must be "
"values that won't be changed by cudagraph padding. "
"Use values from cudagraph_capture_sizes."
)
def set_splitting_ops_for_v1(
self, all2all_backend: str, data_parallel_size: int = 1
):
......@@ -1134,24 +1105,6 @@ class CompilationConfig:
self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes
# Recompute after adjusting the cudagraph sizes
self.compute_bs_to_padded_graph_size()
def compute_bs_to_padded_graph_size(self):
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_cudagraph_capture_size + 1)
]
for end, start in zip(
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
......
......@@ -57,13 +57,47 @@ class CudagraphDispatcher:
)
self.keys_initialized = False
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE
def _compute_bs_to_padded_graph_size(self) -> None:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip(
capture_sizes + [max_size + 1],
[0] + capture_sizes,
):
for bs in range(start, end):
if bs == start:
self._bs_to_padded_graph_size[bs] = start
else:
self._bs_to_padded_graph_size[bs] = end
# Validate that compile_sizes won't be changed by padding.
# Only validate when cudagraphs are actually being used.
if (
self.compilation_config.compile_sizes
and self.cudagraph_mode != CUDAGraphMode.NONE
):
for size in self.compilation_config.compile_sizes:
if size <= self.compilation_config.max_cudagraph_capture_size:
padded = self._bs_to_padded_graph_size[size]
if padded != size:
raise ValueError(
f"compile_sizes contains {size} which would be "
f"padded to {padded}. All compile_sizes must be "
"values that won't be changed by cudagraph padding. "
"Use values from cudagraph_capture_sizes."
)
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
......@@ -88,12 +122,19 @@ class CudagraphDispatcher:
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1
):
# This should be called only after attention backend is initialized. So we can
# get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode
# Early exit if cudagraphs are disabled
if cudagraph_mode == CUDAGraphMode.NONE:
self.keys_initialized = True
return
self._compute_bs_to_padded_graph_size()
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
......@@ -143,15 +184,24 @@ class CudagraphDispatcher:
def dispatch(
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
uniform_decode: bool = False,
has_lora: bool = False,
disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
Given conditions(e.g.,batch descriptor and if using piecewise only),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
Args:
num_tokens: Number of tokens in the batch.
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
"""
if (
not self.keys_initialized
......
......@@ -9,7 +9,6 @@ import torch
import torch.nn as nn
from vllm.config import (
CompilationMode,
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
......@@ -36,6 +35,7 @@ from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
......@@ -100,24 +100,13 @@ class SpecDecodeBaseProposer:
self._get_eagle3_use_aux_hidden_state_from_config()
)
self.use_cuda_graph = False
self.compilation_config = self.vllm_config.compilation_config
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
CUDAGraphMode.PIECEWISE
):
logger.warning(
"Currently the eagle proposer only supports cudagraph_mode "
"PIECEWISE, if you want the drafter to use cuda graphs, "
"please set compilation_config.cudagraph_mode to PIECEWISE "
"or FULL_AND_PIECEWISE"
)
self.use_cuda_graph = (
cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
and not self.speculative_config.enforce_eager
)
# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
# Keys are initialized later via initialize_cudagraph_keys() called from
# gpu_model_runner._check_and_update_cudagraph_mode after
# adjust_cudagraph_sizes_for_spec_decode is called.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
......@@ -234,6 +223,23 @@ class SpecDecodeBaseProposer:
else:
self.positions[:num_tokens] = positions
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle.
Eagle only supports PIECEWISE cudagraphs (via mixed_mode).
This should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
if (
not self.speculative_config.enforce_eager
and cudagraph_mode.mixed_mode()
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
):
eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
eagle_cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def propose(
self,
# [num_tokens]
......@@ -304,16 +310,10 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
if (
self.use_cuda_graph
and num_tokens_dp_padded
<= self.compilation_config.max_cudagraph_capture_size
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
num_input_tokens = num_tokens_dp_padded
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
)
num_input_tokens = batch_desc.num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
......@@ -412,16 +412,10 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
)
if (
self.use_cuda_graph
and batch_size_dp_padded
<= self.compilation_config.max_cudagraph_capture_size
):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
input_batch_size = batch_size_dp_padded
cudagraph_runtime_mode = CUDAGraphMode.NONE
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
......@@ -870,15 +864,10 @@ class SpecDecodeBaseProposer:
self.positions[:num_tokens] = tree_positions.view(-1)
self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
if (
self.use_cuda_graph
and num_tokens <= self.compilation_config.max_cudagraph_capture_size
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
num_input_tokens = num_tokens
cudagraph_runtime_mode = CUDAGraphMode.NONE
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens
)
num_input_tokens = batch_desc.num_tokens
# Run the model.
with set_forward_context(
per_layer_attn_metadata,
......@@ -1216,9 +1205,6 @@ class SpecDecodeBaseProposer:
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
for fwd_idx in range(
......@@ -1228,16 +1214,10 @@ class SpecDecodeBaseProposer:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
if (
cudagraphs_enabled
and num_tokens_dp_padded
<= self.compilation_config.max_cudagraph_capture_size
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_tokens_dp_padded
)
else:
num_input_tokens = num_tokens_dp_padded
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
)
num_input_tokens = batch_desc.num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
......@@ -1246,9 +1226,7 @@ class SpecDecodeBaseProposer:
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
if cudagraphs_enabled
else CUDAGraphMode.NONE,
cudagraph_runtime_mode=cudagraph_runtime_mode,
):
if self.supports_mm_inputs:
input_ids = None
......@@ -1340,7 +1318,8 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=num_tokens_unpadded,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
allow_dp_padding=self.use_cuda_graph,
allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
!= CUDAGraphMode.NONE,
num_tokens_padded=num_tokens_padded,
uniform_decode=None,
num_scheduled_tokens_per_request=None,
......
......@@ -2139,15 +2139,11 @@ class GPUModelRunner(
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item()
)
if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]
):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
else:
num_logits_padded = num_logits
# Dispatch for the decoder portion of the model.
_, batch_desc = self.cudagraph_dispatcher.dispatch(
num_logits, disable_full=True
)
num_logits_padded = batch_desc.num_tokens
logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
:num_logits_padded
]
......@@ -5212,6 +5208,11 @@ class GPUModelRunner(
cudagraph_mode, self.uniform_decode_query_len
)
# Initialize eagle's cudagraph dispatcher if using eagle spec decode.
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
def calculate_reorder_batch_threshold(self) -> None:
"""
Choose the minimum reorder batch threshold from all attention groups.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment