Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b5f350d
Unverified
Commit
5b5f350d
authored
Aug 19, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 19, 2025
Browse files
[Misc] Enable yapf for FlashInfer backend (#23193)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
f7cf5b51
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
13 deletions
+24
-13
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+24
-13
No files found.
vllm/v1/attention/backends/flashinfer.py
View file @
5b5f350d
...
...
@@ -36,6 +36,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
get_per_layer_parameters
,
infer_global_hyperparameters
,
split_decodes_and_prefills
)
# yapf: enable
from
vllm.v1.kv_cache_interface
import
AttentionSpec
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
...
...
@@ -541,12 +542,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
cache_dtype
.
startswith
(
"fp8"
)
and
enable_fusion
:
q_dtype
=
kv_cache_dtype
prefill_use_trtllm
=
use_trtllm_attention
(
num_qo_heads
,
num_kv_heads
,
num_prefill_tokens
,
max_seq_len
,
cache_dtype
,
q_dtype
,
is_prefill
=
True
,
has_sinks
=
has_sinks
)
decode_use_trtllm
=
use_trtllm_attention
(
num_qo_heads
,
num_kv_heads
,
num_decode_tokens
,
max_seq_len
,
cache_dtype
,
q_dtype
,
is_prefill
=
False
,
has_sinks
=
has_sinks
)
prefill_use_trtllm
=
use_trtllm_attention
(
num_qo_heads
,
num_kv_heads
,
num_prefill_tokens
,
max_seq_len
,
cache_dtype
,
q_dtype
,
is_prefill
=
True
,
has_sinks
=
has_sinks
)
decode_use_trtllm
=
use_trtllm_attention
(
num_qo_heads
,
num_kv_heads
,
num_decode_tokens
,
max_seq_len
,
cache_dtype
,
q_dtype
,
is_prefill
=
False
,
has_sinks
=
has_sinks
)
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
...
...
@@ -654,19 +665,18 @@ class FlashInferImpl(AttentionImpl):
raise
ValueError
(
"Sinks must have the same number of heads as the number of "
f
"heads in the layer. Expected
{
num_heads
}
, but got "
f
"
{
sinks
.
shape
[
0
]
}
."
)
f
"
{
sinks
.
shape
[
0
]
}
."
)
self
.
sinks
=
sinks
self
.
support_trtllm_attn
=
(
supports_trtllm_attention
()
and
num_heads
%
num_kv_heads
==
0
)
self
.
support_trtllm_attn
=
(
supports_trtllm_attention
()
and
num_heads
%
num_kv_heads
==
0
)
self
.
bmm1_scale
:
Optional
[
float
]
=
None
self
.
bmm2_scale
:
Optional
[
float
]
=
None
def
fused_output_quant_supported
(
self
,
dtype
:
torch
.
dtype
,
static
:
bool
,
group_shape
:
GroupShape
):
supported_quant_type
=
(
dtype
==
FP8_DTYPE
and
static
and
group_shape
==
GroupShape
.
PER_TENSOR
)
supported_quant_type
=
(
dtype
==
FP8_DTYPE
and
static
and
group_shape
==
GroupShape
.
PER_TENSOR
)
return
(
self
.
support_trtllm_attn
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
supported_quant_type
)
...
...
@@ -731,7 +741,8 @@ class FlashInferImpl(AttentionImpl):
# Insert FP8 quant for query
num_tokens
,
num_heads
,
head_size
=
query
.
shape
query
,
_
=
ops
.
scaled_fp8_quant
(
query
.
reshape
((
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
query
.
reshape
(
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment