"vllm/vscode:/vscode.git/clone" did not exist on "1aa13615103c2ea47e36710a9b2e17dfe1909143"
Unverified Commit 85fee74b authored by fhl2000's avatar fhl2000 Committed by GitHub
Browse files

[Bugfix][CI] Move resolving cudagraph_mode before initializing attn_metadata_builder (#27427)


Signed-off-by: default avatarfhl2000 <63384265+fhl2000@users.noreply.github.com>
parent 8dbe0c52
......@@ -167,7 +167,7 @@ class AttentionCGSupport(enum.Enum):
"""NO CUDA Graphs support"""
```
Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].
Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code for [this][vllm.v1.worker.gpu_model_runner.GPUModelRunner._check_and_update_cudagraph_mode].
The following table lists backends that support full CUDA Graphs at the time of writing.
......
......@@ -132,6 +132,9 @@ def test_attn_quant(
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
# FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
# CUDAGraphMode.NONE here because it derives an attention backend that
# does not support full cudagraphs
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
......
......@@ -3751,8 +3751,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
else:
self.initialize_cudagraph_capture()
compilation_counter.num_gpu_runner_capture_triggers += 1
......@@ -3926,7 +3924,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
)
......@@ -3955,7 +3953,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_backend, layer_kv_cache_spec
)
attn_backend_layers[key].append(layer_name)
return {attn_backends[k]: v for k, v in attn_backend_layers.items()}
return (
{attn_backends[k]: v for k, v in attn_backend_layers.items()},
set(group_key.attn_backend for group_key in attn_backends.values()),
)
def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]],
......@@ -3976,14 +3977,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_groups.append(attn_group)
return attn_groups
attention_backend_maps = []
attention_backend_set: set[type[AttentionBackend]] = set()
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends))
attention_backend_maps.append(attn_backends[0])
attention_backend_set.update(attn_backends[1])
# Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)
for attn_backends_map in attention_backend_maps:
self.attn_groups.append(create_attn_groups(attn_backends_map))
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
def initialize_cudagraph_capture(self) -> None:
def _check_and_update_cudagraph_mode(
self, attention_backends: set[type[AttentionBackend]]
) -> None:
"""
Resolve the cudagraph_mode when there are multiple attention
backends with potential conflicting CUDA graph support.
......@@ -3991,13 +4003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_mode.
"""
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_builder_name = None
min_cg_backend_name = None
for attn_group in self._attn_group_iterator():
builder = attn_group.get_metadata_builder()
if builder.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder.cudagraph_support
min_cg_builder_name = builder.__class__.__name__
for attn_backend in attention_backends:
builder_cls = attn_backend.get_builder_cls()
if builder_cls.cudagraph_support.value < min_cg_support.value:
min_cg_support = builder_cls.cudagraph_support
min_cg_backend_name = attn_backend.__name__
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
# check cudagraph for mixed batch is supported
......@@ -4007,7 +4019,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if min_cg_support == AttentionCGSupport.NEVER:
......@@ -4038,7 +4050,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
......@@ -4072,7 +4084,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_builder_name} (support: {min_cg_support})"
f"{min_cg_backend_name} (support: {min_cg_support})"
)
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
......@@ -4094,14 +4106,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
raise ValueError(
f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_builder_name} backend ("
f"supported with {min_cg_backend_name} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation mode is VLLM_COMPILE"
)
# Trigger cudagraph dispatching keys initialization here (after
# initializing attn backends).
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment