Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7cfa4b24
Unverified
Commit
7cfa4b24
authored
Oct 03, 2025
by
Angela Yi
Committed by
GitHub
Oct 03, 2025
Browse files
[BugFix] Fix de-functionalization pass for rotary_embedding (#23953)
Signed-off-by:
angelayi
<
yiangela7@gmail.com
>
parent
b71fcd49
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
266 additions
and
87 deletions
+266
-87
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/compile/test_functionalization.py
tests/compile/test_functionalization.py
+228
-70
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+37
-17
No files found.
.buildkite/test-pipeline.yaml
View file @
7cfa4b24
...
...
@@ -397,6 +397,7 @@ steps:
-
pytest -v -s compile/test_pass_manager.py
-
pytest -v -s compile/test_fusion.py
-
pytest -v -s compile/test_fusion_attn.py
-
pytest -v -s compile/test_functionalization.py
-
pytest -v -s compile/test_silu_mul_quant_fusion.py
-
pytest -v -s compile/test_sequence_parallelism.py
-
pytest -v -s compile/test_async_tp.py
...
...
tests/compile/test_functionalization.py
View file @
7cfa4b24
...
...
@@ -5,54 +5,237 @@ import pytest
import
torch
import
vllm.envs
as
envs
from
vllm
import
LLM
,
SamplingParams
from
vllm.compilation.activation_quant_fusion
import
ActivationQuantFusionPass
from
vllm.compilation.fix_functionalization
import
FixFunctionalizationPass
from
vllm.compilation.fusion
import
FUSED_OPS
,
RMSNormQuantFusionPass
from
vllm.compilation.fusion
import
RMSNormQuantFusionPass
from
vllm.compilation.fx_utils
import
find_auto_fn
,
find_auto_fn_maybe
,
is_func
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.compilation.post_cleanup
import
PostCleanupPass
from
vllm.config
import
CompilationConfig
,
PassConfig
,
VllmConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
)
GroupShape
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
.backend
import
TestBackend
OPS_IN_MODEL
=
[
torch
.
ops
.
_C
.
rotary_embedding
.
default
,
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
]
# Platform probes evaluated once at import time. TEST_FP8 gates the FP8
# branches of the test modules defined below; FP8_DTYPE is the dtype their
# FP8 weights are cast to.
TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module):
    """Small module that applies ``SiluAndMul``, optionally followed by an
    FP8 linear layer when ``TEST_FP8`` is true.

    The module also reports which custom ops the test should (and should
    not) find in the compiled graph via ``ops_in_model`` and
    ``ops_not_in_model``.
    """

    def __init__(self, hidden_size: int = 128):
        """Create the activation and, when ``TEST_FP8``, the FP8 weight/op.

        Args:
            hidden_size: Size of the (square) FP8 weight matrix.
        """
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        # wscale is passed as the weight scale, scale as the input scale of
        # the FP8 linear in forward(). Creation order is kept so RNG
        # consumption is unchanged.
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        if TEST_FP8:
            self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        """Run silu-and-mul on ``x``; with ``TEST_FP8`` also apply the FP8 linear.

        Returns:
            The FP8 linear output when ``TEST_FP8`` is true, otherwise the
            ``SiluAndMul`` output.
        """
        y = self.silu_and_mul(x)
        if not TEST_FP8:
            return y
        # Pass the dedicated input scale rather than reusing the weight
        # scale; previously self.scale was created but never used.
        return self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.scale)

    def example_inputs(self, num_tokens=32, hidden_size=128):
        """Return a 1-tuple with a random ``(num_tokens, 2 * hidden_size)`` input.

        The dtype is float16 when ``TEST_FP8`` is true and float32 otherwise.
        """
        dtype = torch.float16 if TEST_FP8 else torch.float32
        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)

    def ops_in_model(self, do_fusion):
        """Return the ops the test looks for in the graph.

        The quantized silu-and-mul op is returned when ``TEST_FP8`` and
        ``do_fusion`` are both true; otherwise the plain silu-and-mul op.
        """
        if TEST_FP8 and do_fusion:
            return [torch.ops._C.silu_and_mul_quant.default]
        return [torch.ops._C.silu_and_mul.default]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the graph (none here)."""
        return []
class
TestFusedAddRMSNorm
(
torch
.
nn
.
Module
):
def __init__(self, hidden_size=16, intermediate_size=32):
    """Build the projection weight, the RMSNorm and (with TEST_FP8) the FP8 linear.

    Args:
        hidden_size: Column count of ``gate_proj`` and of the reshaped input.
        intermediate_size: Row count of ``gate_proj`` and size of the norm.
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size

    # Parameters are float16 when FP8 is being tested, float32 otherwise.
    param_dtype = torch.float32
    if TEST_FP8:
        param_dtype = torch.float16

    proj_weight = torch.empty((intermediate_size, hidden_size), dtype=param_dtype)
    self.gate_proj = torch.nn.Parameter(proj_weight)
    self.norm = RMSNorm(intermediate_size, 1e-05)
    # Replace the norm weight with ones in the dtype chosen above.
    norm_weight = torch.ones(intermediate_size, dtype=param_dtype)
    self.norm.weight = torch.nn.Parameter(norm_weight)

    torch.nn.init.normal_(self.gate_proj, std=0.02)

    if not TEST_FP8:
        return

    # FP8-only state: the linear op, the input scale, the transposed FP8
    # weight and the weight scale (created in this order).
    self.fp8_linear = Fp8LinearOp(act_quant_static=True)
    self.scale = torch.rand(1, dtype=torch.float32)
    fp8_weight = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE)
    self.w = fp8_weight.t()
    self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
    """Project ``hidden_states``, apply the residual RMSNorm, then optionally
    the FP8 linear.

    Returns:
        ``(output, residual_output)`` where ``output`` is the FP8 linear
        result when ``TEST_FP8`` is true and the norm output otherwise.
    """
    # Flatten to 2-D with hidden_size columns.
    flat = hidden_states.reshape(-1, self.hidden_size)

    # Multiply by the transposed projection weight.
    transposed_weight = self.gate_proj.permute(1, 0)
    projected = torch.mm(flat, transposed_weight)

    # Norm called with a residual returns (normed, new_residual).
    norm_output, residual_output = self.norm(projected, residual)

    if not TEST_FP8:
        return norm_output, residual_output

    # scaled_mm with static input quantization
    input_scale = self.scale.to(norm_output.device)
    fp8_linear_result = self.fp8_linear.apply(
        norm_output,
        self.w,
        self.wscale,
        input_scale=input_scale,
    )
    return fp8_linear_result, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
    """Return random ``(hidden_states, residual)`` tensors.

    Both have shape ``(batch_size * seq_len, hidden_size)``; the dtype is
    float16 when ``TEST_FP8`` is true and float32 otherwise.
    """
    # NOTE(review): forward() feeds the norm a tensor with intermediate_size
    # columns, while residual here has hidden_size columns. Confirm the
    # mismatch is intended.
    dtype = torch.float16 if TEST_FP8 else torch.float32
    shape = (batch_size * seq_len, hidden_size)
    hidden_states = torch.randn(shape, dtype=dtype)
    residual = torch.randn(shape, dtype=dtype)
    return hidden_states, residual
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
def ops_in_model(self, do_fusion):
    """Return the ops the test looks for in the graph.

    The static-FP8-quant fused add+RMSNorm op is returned when ``TEST_FP8``
    and ``do_fusion`` are both true; otherwise the plain fused add+RMSNorm op.
    """
    use_quant_op = TEST_FP8 and do_fusion
    packet = (
        torch.ops._C.fused_add_rms_norm_static_fp8_quant
        if use_quant_op
        else torch.ops._C.fused_add_rms_norm
    )
    return [packet.default]
RMS_QUANT_OPS
=
{
"static_fp8"
:
[
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
.
default
,
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
],
}
def ops_not_in_model(self):
    """Return the ops that must be absent from the graph (none here)."""
    return list()
SILU_MUL_OP
=
torch
.
ops
.
_C
.
silu_and_mul
.
default
SILU_MUL_QUANT_OP
=
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
class TestRotaryEmbedding(torch.nn.Module):
    """Module that only applies a rotary embedding obtained from ``get_rope``
    to a query/key pair."""

    def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
        """Store the dimensions and build the rotary embedding.

        Args:
            head_dim: Head size passed to ``get_rope``.
            rotary_dim: Rotary dimension; a falsy value falls back to
                ``head_dim``.
            max_position: Forwarded to ``get_rope``.
            base: Forwarded to ``get_rope``.
        """
        super().__init__()
        self.head_dim = head_dim
        self.rotary_dim = rotary_dim if rotary_dim else head_dim

        rope_kwargs = dict(
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            base=base,
        )
        self.rotary_emb = get_rope(self.head_dim, **rope_kwargs)

    def forward(self, positions, q, k):
        """Apply the rotary embedding and return ``(q_rotated, k_rotated)``."""
        new_q, new_k = self.rotary_emb(positions, q, k)
        return new_q, new_k

    def example_inputs(self, num_tokens=32, head_dim=64):
        """Return ``(positions, q, k)`` sample inputs.

        ``positions`` is ``arange(num_tokens)`` as int64; ``q`` and ``k``
        are random float16 tensors of shape ``(num_tokens, head_dim)``.
        """
        half = torch.float16
        qk_shape = (num_tokens, head_dim)

        positions = torch.arange(num_tokens, dtype=torch.long)
        # q is drawn before k, so RNG consumption order is unchanged.
        q = torch.randn(*qk_shape, dtype=half)
        k = torch.randn(*qk_shape, dtype=half)
        return positions, q, k

    def ops_in_model(self, do_fusion):
        """Return the ops the test looks for in the graph.

        ``do_fusion`` is ignored; the result is always the rotary op.
        """
        rotary_op = torch.ops._C.rotary_embedding.default
        return [rotary_op]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the graph (none here)."""
        return list()
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
    """Module that applies a rotary embedding to slices of a fused QKV
    projection and concatenates the result back together."""

    def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
        """Build the bias-free float16 QKV projection and the rotary embedding.

        Args:
            head_dim: Per-head size; also used as the rotary dimension.
            num_heads: Number of heads; ``hidden_size = head_dim * num_heads``.
            max_position: Forwarded to ``get_rope``.
            base: Forwarded to ``get_rope``.
        """
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        # One linear layer produces q, k and v side by side.
        self.qkv_proj = torch.nn.Linear(
            self.hidden_size,
            self.hidden_size * 3,
            bias=False,
            dtype=torch.float16,
        )

        rope_kwargs = dict(
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=base,
        )
        self.rotary_emb = get_rope(self.head_dim, **rope_kwargs)

    def forward(self, positions, hidden_states):
        """Project to QKV, rotate q and k, and return the re-concatenated QKV."""
        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes
        fused_qkv = self.qkv_proj(hidden_states)
        chunk_sizes = [self.hidden_size] * 3
        q, k, v = torch.split(fused_qkv, chunk_sizes, dim=-1)

        new_q, new_k = self.rotary_emb(positions, q, k)

        return torch.cat([new_q, new_k, v], dim=-1)

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        """Return ``(positions, hidden_states)`` sample inputs.

        ``positions`` is ``arange(num_tokens)`` as int64; ``hidden_states``
        is random float16 of shape ``(num_tokens, head_dim * num_heads)``.
        """
        width = head_dim * num_heads
        positions = torch.arange(num_tokens, dtype=torch.long)
        hidden_states = torch.randn(num_tokens, width, dtype=torch.float16)
        return positions, hidden_states

    def ops_in_model(self, do_fusion):
        """Return the ops the test looks for in the graph.

        ``do_fusion`` is ignored; the result is always the rotary op.
        """
        rotary_op = torch.ops._C.rotary_embedding.default
        return [rotary_op]

    def ops_not_in_model(self):
        """Return the ops that must be absent from the graph: ``slice_scatter``."""
        scatter_op = torch.ops.aten.slice_scatter.default
        return [scatter_op]
# Module classes that test_fix_functionalization is parametrized over
# (as "model_class").
MODELS = [
    TestSiluMul,
    TestFusedAddRMSNorm,
    TestRotaryEmbedding,
    TestRotaryEmbeddingSliceScatter,
]
@
pytest
.
mark
.
parametrize
(
"model, quant_key"
,
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
,
kFp8StaticTensorSym
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e"
,
kFp8DynamicTokenSym
)])
@
pytest
.
mark
.
parametrize
(
"model_class"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"do_fusion"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
!=
"cuda"
,
reason
=
"Only test on CUDA"
)
def
test_fix_functionalization
(
model
:
str
,
quant_key
:
QuantKey
,
do_fusion
:
bool
):
def
test_fix_functionalization
(
model_class
:
torch
.
nn
.
Module
,
do_fusion
:
bool
):
torch
.
set_default_device
(
"cuda"
)
vllm_config
=
VllmConfig
()
...
...
@@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
cleanup_pass
=
PostCleanupPass
(
vllm_config
)
act_quant_fusion_pass
=
ActivationQuantFusionPass
(
vllm_config
)
passes
=
[
noop_pass
,
fusion_pass
,
act_quant_fusion_pass
,
cleanup_pass
]
if
do_fusion
else
[
noop_pass
,
cleanup_pass
]
passes
=
(
[
noop_pass
,
fusion_pass
,
act_quant_fusion_pass
,
cleanup_pass
]
if
do_fusion
else
[
noop_pass
,
cleanup_pass
]
)
func_pass
=
FixFunctionalizationPass
(
vllm_config
)
backend_func
=
TestBackend
(
*
passes
,
func_pass
)
backend_no_func
=
TestBackend
(
*
passes
)
# instantiate a full engine and manually compile the model 2x
# (with and without FixFunctionalizationPass)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
)
model_runner
=
llm
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
orig_model
=
model_runner
.
model
# TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
# Can only do that by using the decorator but then we'd have to instantiate
# 2 LLM instances.
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
model_runner
.
model
=
torch
.
compile
(
orig_model
,
fullgraph
=
True
,
backend
=
backend_func
)
gen_func
=
llm
.
generate
(
prompts
,
sampling_params
)
model_runner
.
model
=
torch
.
compile
(
orig_model
,
fullgraph
=
True
,
backend
=
backend_no_func
)
gen_no_func
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output_func
,
output_no_func
in
zip
(
gen_func
,
gen_no_func
):
assert
output_func
.
outputs
[
0
].
text
==
output_no_func
.
outputs
[
0
].
text
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops
=
[
FUSED_OPS
[(
quant_key
,
True
)],
FUSED_OPS
[(
quant_key
,
False
)]
]
if
do_fusion
else
[
RMS_OP
]
silu_mul_ops
=
[
SILU_MUL_QUANT_OP
]
if
do_fusion
and
\
quant_key
==
kFp8StaticTensorSym
else
[
SILU_MUL_OP
]
ops
=
OPS_IN_MODEL
+
rms_ops
+
silu_mul_ops
for
op
in
ops
:
model
=
model_class
()
torch
.
compile
(
model
,
backend
=
backend_func
)(
*
model
.
example_inputs
())
torch
.
compile
(
model
,
backend
=
backend_no_func
)(
*
model
.
example_inputs
())
# check if the functionalization pass is applied
for
op
in
model
.
ops_in_model
(
do_fusion
):
find_auto_fn
(
backend_no_func
.
graph_post_pass
.
nodes
,
op
)
assert
find_auto_fn_maybe
(
backend_func
.
graph_post_pass
.
nodes
,
op
)
is
None
# noqa: E501
assert
(
find_auto_fn_maybe
(
backend_func
.
graph_post_pass
.
nodes
,
op
)
is
None
)
# noqa: E501
# make sure the ops were all de-functionalized
found
=
dict
()
for
node
in
backend_func
.
graph_post_pass
.
nodes
:
for
op
in
ops
:
for
op
in
model
.
ops_in_model
(
do_fusion
):
if
is_func
(
node
,
op
):
found
[
op
]
=
True
for
op
in
model
.
ops_not_in_model
():
if
is_func
(
node
,
op
):
found
[
op
]
=
True
assert
all
(
found
[
op
]
for
op
in
ops
)
assert
all
(
found
[
op
]
for
op
in
model
.
ops_in_model
(
do_fusion
))
assert
all
(
not
found
.
get
(
op
)
for
op
in
model
.
ops_not_in_model
())
vllm/compilation/fix_functionalization.py
View file @
7cfa4b24
...
...
@@ -46,14 +46,25 @@ class FixFunctionalizationPass(VllmInductorPass):
if
at_target
==
torch
.
ops
.
_C
.
rotary_embedding
.
default
:
query
=
kwargs
[
'query'
]
mm_node
=
query
.
args
[
0
].
args
[
0
]
key
=
kwargs
[
'key'
]
getitem_nodes
=
self
.
getitem_users
(
node
)
if
(
is_func
(
query
,
operator
.
getitem
)
and
is_func
(
key
,
operator
.
getitem
)
and
query
.
args
[
0
]
==
key
.
args
[
0
]
and
is_func
(
query
.
args
[
0
],
torch
.
ops
.
aten
.
split_with_sizes
.
default
)
and
all
(
is_func
(
user
,
torch
.
ops
.
aten
.
slice_scatter
.
default
)
for
getitem_node
in
getitem_nodes
.
values
()
for
user
in
getitem_node
.
users
)):
# Pattern where query and key are slices of an mm_node.
# While functionalized, results at [1] and [2] are scattered
# back into mm_node. So after de-functionalization, we can
# just use mm_node directly.
# rotary_embedding is a special case: the two mutating inputs
# are query and key, which are slices of mm_node.
# While functionalized, results at[1] and at[2] are scattered
# back into mm_node. After de-functionalization, we can just
# use mm_node directly.
for
idx
,
user
in
self
.
getitem_users
(
node
).
items
():
mm_node
=
query
.
args
[
0
].
args
[
0
]
for
user
in
getitem_nodes
.
values
():
for
user_of_getitem
in
user
.
users
:
if
is_func
(
user_of_getitem
,
torch
.
ops
.
aten
.
slice_scatter
.
default
):
...
...
@@ -64,6 +75,15 @@ class FixFunctionalizationPass(VllmInductorPass):
self
.
insert_defunctionalized
(
graph
,
node
)
self
.
_remove
(
node
)
else
:
# Directly replace the auto_functionalize(rotary_embedding)
# with the inplace rotary_embedding. In theory, we shouldn't
# do this blindly, but in practice in vLLM it's ok. The best
# solution is to use auto_functionalization_v2 and then use
# inductor's builtin defunctionalization (reinplacing) pass.
mutated_args
=
{
1
:
'query'
,
2
:
'key'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
# rms_norm replacements avoid the most copies for LLaMa.
elif
at_target
==
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
:
mutated_args
=
{
1
:
'input'
,
2
:
'residual'
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment