Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e816a881
Unverified
Commit
e816a881
authored
Apr 10, 2026
by
yzong-rh
Committed by
GitHub
Apr 10, 2026
Browse files
[Bugfix] Fix FlashInfer crash with kv_cache_dtype_skip_layers (#39002)
Signed-off-by:
Yifan Zong
<
yzong@redhat.com
>
parent
e281cb72
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
tests/compile/passes/test_fusion_attn.py
tests/compile/passes/test_fusion_attn.py
+5
-8
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+10
-3
No files found.
tests/compile/passes/test_fusion_attn.py
View file @
e816a881
...
@@ -39,7 +39,7 @@ from vllm.platforms import current_platform
...
@@ -39,7 +39,7 @@ from vllm.platforms import current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
get_kv_quant_mode
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
FP4_DTYPE
=
torch
.
uint8
...
@@ -53,7 +53,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
...
@@ -53,7 +53,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_qo_heads
:
int
,
num_qo_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
kv_cache_dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
device
:
torch
.
device
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
block_size
:
int
,
block_size
:
int
,
...
@@ -63,7 +62,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
...
@@ -63,7 +62,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
self
.
num_qo_heads
=
num_qo_heads
self
.
num_qo_heads
=
num_qo_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
device
=
device
self
.
device
=
device
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dtype
=
vllm_config
.
model_config
.
dtype
...
@@ -81,13 +79,14 @@ class AttentionQuantPatternModel(torch.nn.Module):
...
@@ -81,13 +79,14 @@ class AttentionQuantPatternModel(torch.nn.Module):
self
.
block_size
=
block_size
self
.
block_size
=
block_size
# Initialize attn MetadataBuilder
# Initialize attn MetadataBuilder
(match Attention.get_kv_cache_spec)
self
.
builder
=
self
.
attn
.
attn_backend
.
get_builder_cls
()(
self
.
builder
=
self
.
attn
.
attn_backend
.
get_builder_cls
()(
kv_cache_spec
=
AttentionSpec
(
kv_cache_spec
=
AttentionSpec
(
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
head_size
=
self
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
dtype
=
self
.
attn
.
kv_cache_torch_dtype
,
kv_quant_mode
=
get_kv_quant_mode
(
self
.
attn
.
kv_cache_dtype
),
),
),
layer_names
=
[
self
.
attn
.
layer_name
],
layer_names
=
[
self
.
attn
.
layer_name
],
vllm_config
=
self
.
vllm_config
,
vllm_config
=
self
.
vllm_config
,
...
@@ -126,7 +125,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
...
@@ -126,7 +125,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
# Create dummy KV cache
# Create dummy KV cache
raw_tensor
=
torch
.
zeros
(
raw_tensor
=
torch
.
zeros
(
2
*
num_blocks
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
,
2
*
num_blocks
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
dtype
=
self
.
attn
.
kv_cache_
torch_
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
raw_tensor
=
raw_tensor
.
view
(
kv_cache_shape
)
raw_tensor
=
raw_tensor
.
view
(
kv_cache_shape
)
...
@@ -348,7 +347,6 @@ def test_attention_quant_pattern(
...
@@ -348,7 +347,6 @@ def test_attention_quant_pattern(
num_qo_heads
=
num_qo_heads
,
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
head_size
=
head_size
,
kv_cache_dtype
=
FP8_DTYPE
,
device
=
device
,
device
=
device
,
vllm_config
=
vllm_config_unfused
,
vllm_config
=
vllm_config_unfused
,
block_size
=
block_size
,
block_size
=
block_size
,
...
@@ -376,7 +374,6 @@ def test_attention_quant_pattern(
...
@@ -376,7 +374,6 @@ def test_attention_quant_pattern(
num_qo_heads
=
num_qo_heads
,
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
head_size
=
head_size
,
kv_cache_dtype
=
FP8_DTYPE
,
device
=
device
,
device
=
device
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
w
=
model_unfused
.
w
,
w
=
model_unfused
.
w
,
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
e816a881
...
@@ -63,7 +63,11 @@ from vllm.v1.attention.backends.utils import (
...
@@ -63,7 +63,11 @@ from vllm.v1.attention.backends.utils import (
from
vllm.v1.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.v1.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.v1.attention.ops.dcp_alltoall
import
dcp_a2a_lse_reduce
from
vllm.v1.attention.ops.dcp_alltoall
import
dcp_a2a_lse_reduce
from
vllm.v1.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.v1.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
UniformTypeKVCacheSpecs
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVQuantMode
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
...
@@ -600,12 +604,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -600,12 +604,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
head_dim
=
self
.
kv_cache_spec
.
head_size
self
.
head_dim
=
self
.
kv_cache_spec
.
head_size
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
self
.
kv_cache_spec
.
kv_quant_mode
!=
KVQuantMode
.
NONE
:
if
is_quantized_kv_cache
(
self
.
cache_dtype
):
self
.
cache_dtype
=
self
.
cache_config
.
cache_dtype
# Cannot use self.kv_cache_spec.dtype here because kv_cache_spec
# storage dtype may not be the same as the op dtype (uint8 vs fp8_e4m3)
self
.
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
cache_dtype
self
.
cache_dtype
)
)
else
:
else
:
self
.
cache_dtype
=
"auto"
assert
self
.
kv_cache_spec
.
dtype
==
self
.
model_config
.
dtype
assert
self
.
kv_cache_spec
.
dtype
==
self
.
model_config
.
dtype
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment