Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11eddf02
Unverified
Commit
11eddf02
authored
Aug 27, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 27, 2025
Browse files
[FlashInfer] Cache hyper params in metadata builder (#23732)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
04ff1e43
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
15 deletions
+15
-15
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+15
-15
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
11eddf02
...
@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# TODO: discard this for trtllm-gen backend
# TODO: discard this for trtllm-gen backend
self
.
global_hyperparameters
=
infer_global_hyperparameters
(
self
.
global_hyperparameters
=
infer_global_hyperparameters
(
get_per_layer_parameters
(
vllm_config
,
layer_names
,
FlashInferImpl
))
get_per_layer_parameters
(
vllm_config
,
layer_names
,
FlashInferImpl
))
self
.
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
self
.
window_left
=
self
.
global_hyperparameters
.
window_left
self
.
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
self
.
has_sinks
=
self
.
global_hyperparameters
.
has_sinks
# Preparing persistent buffers (device-side)
# Preparing persistent buffers (device-side)
self
.
paged_kv_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
self
.
paged_kv_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
...
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
)
# Check if any layer uses sinks (requires TRTLLM attention)
# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks
=
self
.
global_hyperparameters
.
has_sinks
prefill_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
prefill_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
num_prefill_tokens
,
num_prefill_tokens
,
...
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
cache_dtype
,
self
.
cache_dtype
,
self
.
q_data_type
,
self
.
q_data_type
,
is_prefill
=
True
,
is_prefill
=
True
,
has_sinks
=
has_sinks
)
has_sinks
=
self
.
has_sinks
)
decode_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
decode_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
num_decode_tokens
,
num_decode_tokens
,
...
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
cache_dtype
,
self
.
cache_dtype
,
self
.
q_data_type
,
self
.
q_data_type
,
is_prefill
=
False
,
is_prefill
=
False
,
has_sinks
=
has_sinks
)
has_sinks
=
self
.
has_sinks
)
attn_metadata
=
FlashInferMetadata
(
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
...
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
head_dim
,
self
.
head_dim
,
self
.
page_size
,
self
.
page_size
,
causal
=
True
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
)
...
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
head_dim
,
self
.
head_dim
,
self
.
page_size
,
self
.
page_size
,
causal
=
True
,
causal
=
True
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
)
...
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
page_size
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
pos_encoding_mode
=
"NONE"
,
sm_scale
=
self
.
global_hyperparameters
.
sm_scale
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
global_hyperparameters
.
window_left
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
global_hyperparameters
.
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment