Commit 513f17a4 authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa full_cuda_graph support

parent cc6f327a
...@@ -4106,7 +4106,7 @@ class CompilationConfig: ...@@ -4106,7 +4106,7 @@ class CompilationConfig:
are always used, it can set this to False. Otherwise, it should are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.""" internally managed buffer. Default is False."""
full_cuda_graph: bool = False full_cuda_graph: bool = True
"""whether to use a full cuda graph for the entire forward pass rather than """whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide flag cannot be used together with splitting_ops. This may provide
......
...@@ -163,7 +163,7 @@ def _get_sliding_window_configs( ...@@ -163,7 +163,7 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 or current_platform.is_rocm()
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable): block_table: BlockTable):
...@@ -183,7 +183,8 @@ class FlashAttentionMetadataBuilder( ...@@ -183,7 +183,8 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits. self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3) self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph self.use_full_cuda_graph = compilation_config.full_cuda_graph
if not current_platform.is_rocm() and self.use_full_cuda_graph: if self.use_full_cuda_graph:
if not current_platform.is_rocm():
if not self.aot_schedule: if not self.aot_schedule:
raise ValueError( raise ValueError(
"AoT scheduling is required for full cuda graph.") "AoT scheduling is required for full cuda graph.")
...@@ -373,7 +374,7 @@ class FlashAttentionMetadataBuilder( ...@@ -373,7 +374,7 @@ class FlashAttentionMetadataBuilder(
scheduler_metadata = self.scheduler_metadata[:n] scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0 max_num_splits = 0
if (not current_platform.is_rocm() and self.use_full_cuda_graph if (self.use_full_cuda_graph
and num_actual_tokens <= self.max_cudagraph_size): and num_actual_tokens <= self.max_cudagraph_size):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory # NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits, # usage, because the intermediate buffers of size [num_splits,
......
...@@ -2385,7 +2385,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2385,7 +2385,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_table_i, block_table_i,
) )
if (not current_platform.is_rocm() and self.full_cuda_graph if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported): and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError( raise ValueError(
f"Full CUDAGraph not supported for " f"Full CUDAGraph not supported for "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment