Commit cc06b4e8 (unverified), authored by Nicolò Lucchesi, committed by GitHub
Browse files

[Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping...


[Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 03ac6ca8
...@@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation(): ...@@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation():
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
@pytest.mark.parametrize(
    "capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
    [
        # Block budget falls between two capture sizes: everything above
        # the budget is dropped and the max follows the largest survivor.
        (
            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
            512,
            200,
            [1, 2, 4, 8, 16, 32, 64, 128],
            128,
        ),
        # Block budget comfortably above the max: nothing changes.
        (
            [1, 2, 4, 8, 16],
            16,
            1000,
            [1, 2, 4, 8, 16],
            16,
        ),
        # Block budget exactly equal to the max: still nothing to cap.
        (
            [1, 2, 4, 8, 16, 32],
            32,
            32,
            [1, 2, 4, 8, 16, 32],
            32,
        ),
        # Block budget below the smallest capture size: every size is
        # removed and the max collapses to 0.
        (
            [8, 16, 32],
            32,
            4,
            [],
            0,
        ),
        # Non-positive block count: the method returns early and leaves
        # the configuration untouched.
        (
            [1, 2, 4],
            4,
            0,
            [1, 2, 4],
            4,
        ),
    ],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
    capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
    """Check that capture sizes are trimmed to the Mamba cache block budget.

    Each case builds a ``CompilationConfig`` with the given capture sizes,
    applies ``adjust_cudagraph_sizes_for_mamba_cache(num_blocks)`` and
    compares the resulting size list and maximum against the expectation.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    compilation_config = CompilationConfig(
        max_cudagraph_capture_size=max_size,
        cudagraph_capture_sizes=capture_sizes,
        cudagraph_mode=CUDAGraphMode.NONE,
    )

    compilation_config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)

    assert compilation_config.cudagraph_capture_sizes == expected_sizes
    assert compilation_config.max_cudagraph_capture_size == expected_max

    # Whenever any size survives, the list's last entry must be the max.
    if expected_sizes:
        remaining_sizes = compilation_config.cudagraph_capture_sizes
        assert remaining_sizes[-1] == compilation_config.max_cudagraph_capture_size
def test_inductor_asserts_default_disabled(monkeypatch): def test_inductor_asserts_default_disabled(monkeypatch):
"""Test that inductor runtime asserts are disabled by default """Test that inductor runtime asserts are disabled by default
(INFO logging level) on torch < 2.12.""" (INFO logging level) on torch < 2.12."""
......
...@@ -1191,9 +1191,9 @@ def test_is_uniform_decode() -> None: ...@@ -1191,9 +1191,9 @@ def test_is_uniform_decode() -> None:
current_platform.is_rocm(), current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.", reason="Attention backend FLASHINFER is not supported on ROCm.",
) )
def test_cudagraph_sizes_capped_for_mamba_cache(): def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
"""Test that cudagraph capture sizes are capped to num_blocks for """Test that a ValueError is raised when max_num_seqs exceeds the
hybrid models with Mamba layers. available Mamba cache blocks for hybrid models with FULL cudagraphs.
See: https://github.com/vllm-project/vllm/issues/34094 See: https://github.com/vllm-project/vllm/issues/34094
""" """
...@@ -1284,23 +1284,8 @@ def test_cudagraph_sizes_capped_for_mamba_cache(): ...@@ -1284,23 +1284,8 @@ def test_cudagraph_sizes_capped_for_mamba_cache():
)[0] )[0]
num_blocks = kv_cache_config.num_blocks num_blocks = kv_cache_config.num_blocks
# Set max_cudagraph_capture_size to a value larger than num_blocks # Force max_num_seqs to exceed num_blocks so the check triggers.
# to trigger the Mamba capping logic. runner.max_num_reqs = num_blocks + 100
large_max = num_blocks + 100
compilation_config = vllm_config.compilation_config
compilation_config.max_cudagraph_capture_size = large_max
compilation_config.cudagraph_capture_sizes = [
s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max
]
with pytest.raises(ValueError, match="max_num_seqs"):
runner.initialize_kv_cache(kv_cache_config) runner.initialize_kv_cache(kv_cache_config)
# After initialization, cudagraph sizes should be capped
assert compilation_config.max_cudagraph_capture_size <= num_blocks
assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes)
# Invariant: last element == max
if compilation_config.cudagraph_capture_sizes:
assert (
compilation_config.cudagraph_capture_sizes[-1]
== compilation_config.max_cudagraph_capture_size
)
...@@ -1279,58 +1279,6 @@ class CompilationConfig: ...@@ -1279,58 +1279,6 @@ class CompilationConfig:
self.max_cudagraph_capture_size = rounded_sizes[-1] self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes self.cudagraph_capture_sizes = rounded_sizes
def adjust_cudagraph_sizes_for_mamba_cache(
    self, num_mamba_cache_blocks: int
) -> None:
    """Restrict cudagraph capture sizes to the Mamba cache block budget.

    In hybrid Mamba/attention models the first dimension of the Mamba
    conv_state and ssm_state tensors is num_blocks (taken from
    KVCacheConfig). While capturing CUDA graphs the decode batch size
    equals num_tokens, so a capture size larger than num_blocks would
    make Mamba kernels index out of bounds.

    The method mutates ``self.cudagraph_capture_sizes`` and
    ``self.max_cudagraph_capture_size`` in place and logs a warning
    whenever it changes them. It is a no-op when there are no capture
    sizes, when ``num_mamba_cache_blocks`` is not positive, or when the
    block count already covers the current maximum.

    Args:
        num_mamba_cache_blocks: Number of Mamba cache blocks available.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    if not self.cudagraph_capture_sizes or num_mamba_cache_blocks <= 0:
        return

    previous_max = self.max_cudagraph_capture_size
    assert previous_max is not None
    if previous_max <= num_mamba_cache_blocks:
        # Every configured size already fits; nothing to trim.
        return

    fitting_sizes = [
        size
        for size in self.cudagraph_capture_sizes
        if size <= num_mamba_cache_blocks
    ]

    if fitting_sizes:
        # The filter keeps the original ordering, so the last survivor
        # becomes the new maximum.
        new_max = fitting_sizes[-1]
        logger.warning(
            "Capping cudagraph capture sizes from max %d to %d to fit "
            "Mamba cache blocks (%d blocks available). This limits the "
            "maximum batch size that can use CUDA graphs. To increase "
            "this limit, reduce max_num_seqs or increase available GPU "
            "memory.",
            previous_max,
            new_max,
            num_mamba_cache_blocks,
        )
        self.max_cudagraph_capture_size = new_max
        self.cudagraph_capture_sizes = fitting_sizes
        return

    # Not even the first configured size fits: switch capture off.
    logger.warning(
        "No valid cudagraph capture sizes remain after capping "
        "to Mamba cache blocks (%d). The smallest capture size "
        "was %d. Disabling cudagraph capture. Consider reducing "
        "max_num_seqs or increasing available GPU memory.",
        num_mamba_cache_blocks,
        self.cudagraph_capture_sizes[0],
    )
    self.cudagraph_capture_sizes = []
    self.max_cudagraph_capture_size = 0
def get_compile_ranges(self) -> list[Range]: def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config.""" """Get the compile ranges for the compilation config."""
if self.compile_ranges_endpoints is None: if self.compile_ranges_endpoints is None:
......
...@@ -5800,7 +5800,7 @@ class GPUModelRunner( ...@@ -5800,7 +5800,7 @@ class GPUModelRunner(
) )
self.cache_config.num_gpu_blocks_override = saved_override self.cache_config.num_gpu_blocks_override = saved_override
self.initialize_kv_cache(minimal_config) self.initialize_kv_cache(minimal_config, is_profiling=True)
self.cache_config.num_gpu_blocks = minimal_config.num_blocks self.cache_config.num_gpu_blocks = minimal_config.num_blocks
logger.debug("Initialized minimal KV cache for CUDA graph profiling") logger.debug("Initialized minimal KV cache for CUDA graph profiling")
...@@ -6121,7 +6121,11 @@ class GPUModelRunner( ...@@ -6121,7 +6121,11 @@ class GPUModelRunner(
torch.accelerator.synchronize() torch.accelerator.synchronize()
self.maybe_remove_all_loras(self.lora_config) self.maybe_remove_all_loras(self.lora_config)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: def initialize_attn_backend(
self,
kv_cache_config: KVCacheConfig,
is_profiling: bool = False,
) -> None:
""" """
Initialize the attention backends and attention metadata builders. Initialize the attention backends and attention metadata builders.
""" """
...@@ -6193,7 +6197,9 @@ class GPUModelRunner( ...@@ -6193,7 +6197,9 @@ class GPUModelRunner(
# Resolve cudagraph_mode before actually initialize metadata_builders # Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode( self._check_and_update_cudagraph_mode(
attention_backend_list, kv_cache_config.kv_cache_groups attention_backend_list,
kv_cache_config.kv_cache_groups,
is_profiling=is_profiling,
) )
# Check if attention backend supports PCP&DCP and related features. # Check if attention backend supports PCP&DCP and related features.
...@@ -6237,6 +6243,7 @@ class GPUModelRunner( ...@@ -6237,6 +6243,7 @@ class GPUModelRunner(
self, self,
attention_backends: list[set[type[AttentionBackend]]], attention_backends: list[set[type[AttentionBackend]]],
kv_cache_groups: list[KVCacheGroupSpec], kv_cache_groups: list[KVCacheGroupSpec],
is_profiling: bool = False,
) -> None: ) -> None:
""" """
Resolve the cudagraph_mode when there are multiple attention Resolve the cudagraph_mode when there are multiple attention
...@@ -6377,20 +6384,28 @@ class GPUModelRunner( ...@@ -6377,20 +6384,28 @@ class GPUModelRunner(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
) )
# If the model has Mamba layers and cudagraph mode includes FULL # For Mamba models with FULL decode cudagraphs, each decode
# decode, cap cudagraph capture sizes to the number of available # sequence needs one Mamba cache block. The decode cudagraph
# Mamba cache blocks. Each decode request needs one conv_state # dispatcher already caps batch sizes at max_num_seqs, so we just
# cache line, so capture batch sizes cannot exceed num_blocks. # need to verify that enough blocks exist. Raising here instead
# Only FULL decode graphs are affected because PIECEWISE captures # of silently capping cudagraph_capture_sizes avoids unintended
# run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update). # restrictions on PIECEWISE (prefill) cudagraphs.
# See: https://github.com/vllm-project/vllm/issues/34094 # See: https://github.com/vllm-project/vllm/issues/34094
if cudagraph_mode.has_full_cudagraphs(): if cudagraph_mode.has_full_cudagraphs() and not is_profiling:
has_mamba = any( has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups
) )
if has_mamba and self.kv_cache_config is not None: if has_mamba and self.kv_cache_config is not None:
self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache( num_blocks = self.kv_cache_config.num_blocks
self.kv_cache_config.num_blocks if self.max_num_reqs > num_blocks:
raise ValueError(
f"max_num_seqs ({self.max_num_reqs}) exceeds "
f"available Mamba cache blocks ({num_blocks}). "
f"Each decode sequence requires one Mamba cache "
f"block, so CUDA graph capture cannot proceed. "
f"Please lower max_num_seqs to at most "
f"{num_blocks} or increase "
f"gpu_memory_utilization."
) )
# Trigger cudagraph dispatching keys initialization after # Trigger cudagraph dispatching keys initialization after
...@@ -6752,7 +6767,11 @@ class GPUModelRunner( ...@@ -6752,7 +6767,11 @@ class GPUModelRunner(
else: else:
break break
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def initialize_kv_cache(
self,
kv_cache_config: KVCacheConfig,
is_profiling: bool = False,
) -> None:
""" """
Initialize KV cache based on `kv_cache_config`. Initialize KV cache based on `kv_cache_config`.
Args: Args:
...@@ -6764,7 +6783,7 @@ class GPUModelRunner( ...@@ -6764,7 +6783,7 @@ class GPUModelRunner(
self._mamba_copy_bufs = None self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config() self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config) self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling)
# The kernel block size for all KV cache groups. For example, if # The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention # kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return # backends for that group only supports block_size 64, we will return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment