Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af473f0a
Unverified
Commit
af473f0a
authored
Aug 08, 2025
by
Po-Han Huang (NVIDIA)
Committed by
GitHub
Aug 07, 2025
Browse files
[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)
Signed-off-by:
Po-Han Huang
<
pohanh@nvidia.com
>
parent
157f9c13
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
9 deletions
+18
-9
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
...el_executor/layers/quantization/utils/flashinfer_utils.py
+16
-8
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+2
-1
No files found.
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
View file @
af473f0a
...
@@ -6,14 +6,22 @@ import torch
...
@@ -6,14 +6,22 @@ import torch
def
calculate_tile_tokens_dim
(
num_tokens
,
top_k
,
num_experts
):
def
calculate_tile_tokens_dim
(
num_tokens
,
top_k
,
num_experts
):
from
flashinfer
import
next_positive_power_of_2
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# Guess tokens per expert assuming perfect expert distribution first.
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
num_tokens_per_expert
=
(
num_tokens
*
top_k
)
//
num_experts
# with the necessary kernels is released.
# And pad the number to the next power of 2.
tile_tokens_dim
=
8
tile_tokens_dim
=
next_positive_power_of_2
(
num_tokens_per_expert
)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
# from flashinfer import next_positive_power_of_2
tile_tokens_dim
=
min
(
max
(
tile_tokens_dim
,
8
),
64
)
# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return
tile_tokens_dim
return
tile_tokens_dim
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
af473f0a
...
@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim
=
self
.
kv_cache_spec
.
head_size
head_dim
=
self
.
kv_cache_spec
.
head_size
# currently prefill trtllm attention does not support fp8 kv cache
# currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm
=
use_trtllm_attention
(
prefill_use_trtllm
=
not
cache_dtype
.
startswith
(
"fp8"
)
\
and
use_trtllm_attention
(
num_prefill_tokens
,
max_seq_len
,
cache_dtype
,
num_prefill_tokens
,
max_seq_len
,
cache_dtype
,
num_qo_heads
,
num_kv_heads
,
head_dim
)
num_qo_heads
,
num_kv_heads
,
head_dim
)
decode_use_trtllm
=
use_trtllm_attention
(
decode_use_trtllm
=
use_trtllm_attention
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment