Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb9f9730
Unverified
Commit
bb9f9730
authored
Feb 09, 2026
by
Charlie Fu
Committed by
GitHub
Feb 09, 2026
Browse files
[torch.compile][Fusion] Fix attention fusion pass removing kv_update op. (#33945)
Signed-off-by:
charlifu
<
charlifu@amd.com
>
parent
4d396509
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
1 deletion
+22
-1
tests/compile/passes/test_fusion_attn.py
tests/compile/passes/test_fusion_attn.py
+12
-1
vllm/compilation/passes/fusion/attn_quant_fusion.py
vllm/compilation/passes/fusion/attn_quant_fusion.py
+10
-0
No files found.
tests/compile/passes/test_fusion_attn.py
View file @
bb9f9730
...
@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
...
@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
PATTERN_TEST_MODELS_FP8
=
[
PATTERN_TEST_MODELS_FP8
=
[
(
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
TestAttentionFp8StaticQuantPatternModel
)
(
"amd/Llama-3.1-8B-Instruct-FP8-KV"
,
TestAttentionFp8StaticQuantPatternModel
)
]
]
BACKENDS
=
[
BACKENDS
_FP8
=
[
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
,
AttentionBackendEnum
.
ROCM_ATTN
,
AttentionBackendEnum
.
ROCM_ATTN
,
AttentionBackendEnum
.
TRITON_ATTN
,
AttentionBackendEnum
.
TRITON_ATTN
,
...
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
...
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
assert
attn_nodes_pre
[
0
].
kwargs
.
get
(
"output_block_scale"
)
is
None
,
(
assert
attn_nodes_pre
[
0
].
kwargs
.
get
(
"output_block_scale"
)
is
None
,
(
"Attention should not have output_block_scale before fusion"
"Attention should not have output_block_scale before fusion"
)
)
kv_cache_dummy_dep_pre_is_none
=
(
attn_nodes_pre
[
0
].
kwargs
.
get
(
"kv_cache_dummy_dep"
)
is
None
)
kv_cache_dummy_dep_post_is_none
=
(
attn_nodes_post
[
0
].
kwargs
.
get
(
"kv_cache_dummy_dep"
)
is
None
)
assert
not
(
kv_cache_dummy_dep_pre_is_none
^
kv_cache_dummy_dep_post_is_none
),
(
"The kv_cache_dummy_dep should be consistent before and after fusion"
)
if
quant_key
.
dtype
==
FP8_DTYPE
:
if
quant_key
.
dtype
==
FP8_DTYPE
:
assert
attn_nodes_post
[
0
].
kwargs
.
get
(
"output_block_scale"
)
is
None
,
(
assert
attn_nodes_post
[
0
].
kwargs
.
get
(
"output_block_scale"
)
is
None
,
(
"Attention should not have output_block_scale after FP8 fusion"
"Attention should not have output_block_scale after FP8 fusion"
...
...
vllm/compilation/passes/fusion/attn_quant_fusion.py
View file @
bb9f9730
...
@@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
at1
=
auto_functionalized
(
at1
=
auto_functionalized
(
ATTN_OP
,
ATTN_OP
,
...
@@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer_name
=
self
.
layer_name
,
layer_name
=
self
.
layer_name
,
output_scale
=
None
,
output_scale
=
None
,
output_block_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
attn_out_view
=
RESHAPE_OP
(
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
]
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
]
...
@@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
output_attn
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# attn output in quant_dtype
# attn output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
...
@@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer_name
=
self
.
layer_name
,
layer_name
=
self
.
layer_name
,
output_scale
=
scale
,
output_scale
=
scale
,
output_block_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
return
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
return
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
...
@@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
...
@@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
self
.
empty
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
self
.
empty
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
self
.
empty
(
5
,
self
.
num_heads
,
self
.
head_size
),
# attn_output
self
.
empty
(
5
,
self
.
num_heads
,
self
.
head_size
),
# attn_output
empty_fp32
(
1
,
1
),
# scale
empty_fp32
(
1
,
1
),
# scale
self
.
empty
(
0
),
# kv_cache_dummy_dep
]
]
pm
.
register_replacement
(
pm
.
register_replacement
(
...
@@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
at1
=
auto_functionalized
(
at1
=
auto_functionalized
(
ATTN_OP
,
ATTN_OP
,
...
@@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
layer_name
=
self
.
layer_name
,
layer_name
=
self
.
layer_name
,
output_scale
=
None
,
output_scale
=
None
,
output_block_scale
=
None
,
output_block_scale
=
None
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
attn_out_view
=
RESHAPE_OP
(
attn_out_view
=
RESHAPE_OP
(
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
]
at1
[
1
],
[
q
.
shape
[
0
],
self
.
num_heads
*
self
.
head_size
]
...
@@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant
:
torch
.
Tensor
,
output_quant
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
kv_cache_dummy_dep
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# attention output in quant_dtype
# attention output in quant_dtype
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
output_attn
=
torch
.
ops
.
aten
.
full
.
default
(
...
@@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
layer_name
=
self
.
layer_name
,
layer_name
=
self
.
layer_name
,
output_scale
=
input_scale
,
output_scale
=
input_scale
,
output_block_scale
=
output_scale_view
,
output_block_scale
=
output_scale_view
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
output
=
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
//
2
])
output
=
RESHAPE_OP
(
at2
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
//
2
])
return
output
,
at2
[
2
]
return
output
,
at2
[
2
]
...
@@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
128
,
round_up
(
self
.
num_heads
*
self
.
head_size
//
16
,
4
)
128
,
round_up
(
self
.
num_heads
*
self
.
head_size
//
16
,
4
)
),
# output_scale
),
# output_scale
empty_fp32
(
1
,
1
),
# input_scale
empty_fp32
(
1
,
1
),
# input_scale
self
.
empty
(
0
),
# kv_cache_dummy_dep
]
]
pm
.
register_replacement
(
pm
.
register_replacement
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment