Commit 8da1c576 authored by zhuwenwen's avatar zhuwenwen
Browse files

support full_cuda_graph

parent f9408aff
...@@ -183,7 +183,7 @@ class FlashAttentionMetadataBuilder( ...@@ -183,7 +183,7 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits. self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3) self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph self.use_full_cuda_graph = compilation_config.full_cuda_graph
if self.use_full_cuda_graph: if not current_platform.is_rocm() and self.use_full_cuda_graph:
if not self.aot_schedule: if not self.aot_schedule:
raise ValueError( raise ValueError(
"AoT scheduling is required for full cuda graph.") "AoT scheduling is required for full cuda graph.")
...@@ -361,7 +361,7 @@ class FlashAttentionMetadataBuilder( ...@@ -361,7 +361,7 @@ class FlashAttentionMetadataBuilder(
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
causal=True) causal=True)
if self.use_full_cuda_graph: if not current_platform.is_rocm() and self.use_full_cuda_graph:
assert scheduler_metadata is not None assert scheduler_metadata is not None
n = scheduler_metadata.shape[0] n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata self.scheduler_metadata[:n] = scheduler_metadata
...@@ -373,7 +373,7 @@ class FlashAttentionMetadataBuilder( ...@@ -373,7 +373,7 @@ class FlashAttentionMetadataBuilder(
scheduler_metadata = self.scheduler_metadata[:n] scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0 max_num_splits = 0
if (self.use_full_cuda_graph if (not current_platform.is_rocm() and self.use_full_cuda_graph
and num_actual_tokens <= self.max_cudagraph_size): and num_actual_tokens <= self.max_cudagraph_size):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory # NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits, # usage, because the intermediate buffers of size [num_splits,
......
...@@ -67,6 +67,7 @@ from vllm.v1.utils import bind_kv_cache ...@@ -67,6 +67,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.platforms import current_platform
from ..sample.logits_processor import LogitsProcessorManager from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
...@@ -2371,7 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2371,7 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_table_i, block_table_i,
) )
if (self.full_cuda_graph if (not current_platform.is_rocm() and self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported): and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError( raise ValueError(
f"Full CUDAGraph not supported for " f"Full CUDAGraph not supported for "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment