Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b124e108
Unverified
Commit
b124e108
authored
Jun 03, 2025
by
Woosuk Kwon
Committed by
GitHub
Jun 03, 2025
Browse files
[Bugfix] Fix FA3 full cuda graph correctness (#19106)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
41aa5784
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
10 deletions
+32
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/compile/piecewise/test_full_cudagraph.py
tests/compile/piecewise/test_full_cudagraph.py
+5
-2
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+21
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
b124e108
...
@@ -320,6 +320,7 @@ steps:
...
@@ -320,6 +320,7 @@ steps:
# these tests need to be separated, cannot combine
# these tests need to be separated, cannot combine
-
pytest -v -s compile/piecewise/test_simple.py
-
pytest -v -s compile/piecewise/test_simple.py
-
pytest -v -s compile/piecewise/test_toy_llama.py
-
pytest -v -s compile/piecewise/test_toy_llama.py
-
pytest -v -s compile/piecewise/test_full_cudagraph.py
-
label
:
PyTorch Fullgraph Test
# 18min
-
label
:
PyTorch Fullgraph Test
# 18min
mirror_hardwares
:
[
amdexperimental
,
amdproduction
]
mirror_hardwares
:
[
amdexperimental
,
amdproduction
]
...
...
tests/compile/piecewise/test_full_cudagraph.py
View file @
b124e108
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
from
vllm.platforms
import
current_platform
MODEL
=
"Qwen/Qwen2-1.5B-Instruct"
MODEL
=
"Qwen/Qwen2-1.5B-Instruct"
...
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
...
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
"VLLM_FLASH_ATTN_VERSION"
:
"3"
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}):
}):
return
LLM
(
model
=
MODEL
,
return
LLM
(
model
=
MODEL
,
gpu_memory_utilization
=
0.
2
,
gpu_memory_utilization
=
0.
3
,
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
compilation_config
=
CompilationConfig
(
full_cuda_graph
=
True
))
...
@@ -48,7 +49,7 @@ def piecewise_llm():
...
@@ -48,7 +49,7 @@ def piecewise_llm():
"VLLM_FLASH_ATTN_VERSION"
:
"3"
"VLLM_FLASH_ATTN_VERSION"
:
"3"
}):
}):
return
LLM
(
model
=
MODEL
,
return
LLM
(
model
=
MODEL
,
gpu_memory_utilization
=
0.
5
,
gpu_memory_utilization
=
0.
6
,
compilation_config
=
CompilationConfig
())
compilation_config
=
CompilationConfig
())
...
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
...
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
return
llm
.
generate
(
prompts
,
sampling_params
)
return
llm
.
generate
(
prompts
,
sampling_params
)
@
pytest
.
mark
.
skipif
(
current_platform
.
get_device_capability
()
!=
(
9
,
0
),
reason
=
"Only Hopper GPUs support FlashAttention 3"
)
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[(
1
,
10
),
(
7
,
10
),
@
pytest
.
mark
.
parametrize
((
"batch_size"
,
"max_tokens"
),
[(
1
,
10
),
(
7
,
10
),
(
16
,
10
),
(
25
,
10
),
(
16
,
10
),
(
25
,
10
),
(
32
,
10
),
(
45
,
10
),
(
32
,
10
),
(
45
,
10
),
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
b124e108
...
@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
...
@@ -307,13 +307,14 @@ class FlashAttentionMetadataBuilder:
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
self
.
block_table
=
block_table
if
get_flash_attn_version
()
==
3
:
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
aot_schedule
=
not
compilation_config
.
full_cuda_graph
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
if
not
self
.
aot_schedule
:
if
self
.
use_full_cuda_graph
and
not
self
.
aot_schedule
:
logger
.
warning
(
raise
ValueError
(
"Full CUDA graph mode requires AOT scheduling, "
"AOT Schedule is disabled when using full_cuda_graph"
)
"which requires FlashAttention 3."
)
else
:
self
.
scheduler_metadata
=
torch
.
zeros
(
self
.
runner
.
max_num_reqs
+
1
,
self
.
aot_schedule
=
False
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
# Sliding window size to be used with the AOT scheduler will be
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
# populated on first build() call.
...
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
...
@@ -326,7 +327,7 @@ class FlashAttentionMetadataBuilder:
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
):
common_attn_metadata
:
CommonAttentionMetadata
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
)
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
block_table
=
self
.
block_table
block_table
=
self
.
block_table
...
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
...
@@ -448,6 +449,18 @@ class FlashAttentionMetadataBuilder:
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
causal
=
True
)
causal
=
True
)
if
self
.
use_full_cuda_graph
:
assert
scheduler_metadata
is
not
None
n
=
scheduler_metadata
.
shape
[
0
]
self
.
scheduler_metadata
[:
n
].
copy_
(
scheduler_metadata
,
non_blocking
=
True
)
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self
.
scheduler_metadata
[
n
:]
=
0
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b124e108
...
@@ -1750,6 +1750,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1750,6 +1750,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
attn_metadata
:
Optional
[
dict
[
str
,
Any
]]
=
None
else
:
else
:
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
# Make sure max_model_len is used at the graph capture time.
self
.
seq_lens_np
[:
num_reqs
]
=
self
.
max_model_len
self
.
seq_lens_np
[
num_reqs
:]
=
0
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment