Commit 71b02b7a authored by zhuwenwen
Browse files

Update FlashAttention (FA) full_cuda_graph support

parent ffe9e7db
......@@ -4135,7 +4135,7 @@ class CompilationConfig:
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph: bool = False
full_cuda_graph: bool = True
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
......
......@@ -151,7 +151,7 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 or current_platform.is_rocm()
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
......@@ -172,7 +172,8 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
if not current_platform.is_rocm() and self.use_full_cuda_graph:
if self.use_full_cuda_graph:
if not current_platform.is_rocm():
if not self.aot_schedule:
raise ValueError(
"AoT scheduling is required for full cuda graph.")
......@@ -325,7 +326,7 @@ class FlashAttentionMetadataBuilder(
scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0
if (not current_platform.is_rocm() and self.use_full_cuda_graph
if (self.use_full_cuda_graph
and num_actual_tokens <= self.max_cudagraph_size):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
......
......@@ -2548,7 +2548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.device,
)
if (not current_platform.is_rocm() and self.full_cuda_graph
if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment