Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b22290c
Unverified
Commit
2b22290c
authored
Mar 20, 2025
by
Woosuk Kwon
Committed by
GitHub
Mar 20, 2025
Browse files
[V1] Add flag to disable cascade attention (#15243)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
d8e82bc0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
5 deletions
+23
-5
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+12
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+9
-5
No files found.
vllm/config.py
View file @
2b22290c
...
@@ -246,6 +246,7 @@ class ModelConfig:
...
@@ -246,6 +246,7 @@ class ModelConfig:
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
disable_sliding_window
:
bool
=
False
,
disable_cascade_attn
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
served_model_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
...
@@ -322,6 +323,7 @@ class ModelConfig:
...
@@ -322,6 +323,7 @@ class ModelConfig:
self
.
max_seq_len_to_capture
=
max_seq_len_to_capture
self
.
max_seq_len_to_capture
=
max_seq_len_to_capture
self
.
max_logprobs
=
max_logprobs
self
.
max_logprobs
=
max_logprobs
self
.
disable_sliding_window
=
disable_sliding_window
self
.
disable_sliding_window
=
disable_sliding_window
self
.
disable_cascade_attn
=
disable_cascade_attn
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
enable_sleep_mode
=
enable_sleep_mode
self
.
enable_sleep_mode
=
enable_sleep_mode
...
...
vllm/engine/arg_utils.py
View file @
2b22290c
...
@@ -120,6 +120,7 @@ class EngineArgs:
...
@@ -120,6 +120,7 @@ class EngineArgs:
block_size
:
Optional
[
int
]
=
None
block_size
:
Optional
[
int
]
=
None
enable_prefix_caching
:
Optional
[
bool
]
=
None
enable_prefix_caching
:
Optional
[
bool
]
=
None
disable_sliding_window
:
bool
=
False
disable_sliding_window
:
bool
=
False
disable_cascade_attn
:
bool
=
False
use_v2_block_manager
:
bool
=
True
use_v2_block_manager
:
bool
=
True
swap_space
:
float
=
4
# GiB
swap_space
:
float
=
4
# GiB
cpu_offload_gb
:
float
=
0
# GiB
cpu_offload_gb
:
float
=
0
# GiB
...
@@ -1096,6 +1097,16 @@ class EngineArgs:
...
@@ -1096,6 +1097,16 @@ class EngineArgs:
"using. This is used to parse the reasoning content into OpenAI "
"using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``."
)
"API format. Required for ``--enable-reasoning``."
)
parser
.
add_argument
(
"--disable-cascade-attn"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Disable cascade attention for V1. While cascade attention "
"does not change the mathematical correctness, disabling it "
"could be useful for preventing potential numerical issues. "
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial."
)
return
parser
return
parser
@
classmethod
@
classmethod
...
@@ -1141,6 +1152,7 @@ class EngineArgs:
...
@@ -1141,6 +1152,7 @@ class EngineArgs:
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_seq_len_to_capture
=
self
.
max_seq_len_to_capture
,
max_logprobs
=
self
.
max_logprobs
,
max_logprobs
=
self
.
max_logprobs
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
served_model_name
=
self
.
served_model_name
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
2b22290c
...
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
))
weakref
.
proxy
(
self
))
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
input_registry
=
INPUT_REGISTRY
...
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
non_blocking
=
True
)
# Prepare for cascade attention if needed.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
if
self
.
cascade_attn_enabled
:
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
,
scheduler_output
.
num_common_prefix_blocks
,
)
)
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment