Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db74d604
Unverified
Commit
db74d604
authored
Aug 28, 2025
by
Angela Yi
Committed by
GitHub
Aug 28, 2025
Browse files
[Bugfix] Add fake mode around passes (#23349)
Signed-off-by:
angelayi
<
yiangela7@gmail.com
>
parent
95089607
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
66 additions
and
41 deletions
+66
-41
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+2
-0
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+6
-0
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+2
-0
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+34
-41
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+20
-0
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+2
-0
No files found.
vllm/compilation/activation_quant_fusion.py
View file @
db74d604
...
...
@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
...
...
@@ -61,6 +62,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
vllm/compilation/collective_fusion.py
View file @
db74d604
...
...
@@ -19,6 +19,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -349,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class
AsyncTPPass
(
VllmInductorPass
):
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
@@ -1121,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer
fuse_rms_quant
=
config
.
compilation_config
.
pass_config
.
enable_fusion
)
self
.
register_patterns
()
@
enable_fake_mode
def
register_patterns
(
self
):
for
epsilon
in
[
1e-5
,
1e-6
]:
AllReduceFusedRMSNormStaticQuantFP8Pattern
(
epsilon
,
...
...
vllm/compilation/fusion.py
View file @
db74d604
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.platforms
import
current_platform
from
.fx_utils
import
find_getitem_maybe
from
.inductor_pass
import
enable_fake_mode
from
.multi_output_match
import
MultiOutputMatch
from
.vllm_inductor_pass
import
VllmInductorPass
...
...
@@ -528,6 +529,7 @@ class FusionPass(VllmInductorPass):
cls
.
_instance
.
pass_config
=
config
.
compilation_config
.
pass_config
return
cls
.
_instance
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
assert
self
.
__class__
.
_instance
is
None
,
\
"FusionPass singleton instance already exists"
...
...
vllm/compilation/fusion_attn.py
View file @
db74d604
...
...
@@ -7,8 +7,6 @@ import torch
import
torch._inductor.pattern_matcher
as
pm
from
torch._higher_order_ops.auto_functionalize
import
auto_functionalized
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
torch._subclasses.fake_tensor
import
(
FakeTensorMode
,
unset_fake_temporarily
)
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
...
...
@@ -19,6 +17,7 @@ from vllm.platforms import current_platform
from
vllm.utils
import
round_up
from
.fusion
import
QUANT_OPS
,
empty_bf16
,
empty_fp32
,
empty_i32
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
...
...
@@ -139,16 +138,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
output_block_scale
=
None
)
return
RESHAPE_OP
(
at1
[
1
],
[
-
1
,
self
.
num_heads
*
self
.
head_size
])
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with
unset_fake_temporarily
(),
FakeTensorMode
():
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# v
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# attn_output
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
),
# quant_output
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
),
# quant_output
empty_fp32
(
1
,
1
)
# scale
]
...
...
@@ -219,9 +215,6 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
[
-
1
,
self
.
num_heads
*
self
.
head_size
//
2
])
return
output
,
at2
[
2
]
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with
unset_fake_temporarily
(),
FakeTensorMode
():
inputs
=
[
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# q
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# k
...
...
@@ -229,8 +222,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
empty_bf16
(
5
,
self
.
num_heads
,
self
.
head_size
),
# output_attn
self
.
empty_quant
(
5
,
self
.
num_heads
*
self
.
head_size
//
2
),
# output_quant
empty_i32
(
128
,
round_up
(
self
.
num_heads
*
self
.
head_size
//
16
,
empty_i32
(
128
,
round_up
(
self
.
num_heads
*
self
.
head_size
//
16
,
4
)),
# output_scale
empty_fp32
(
1
,
1
),
# input_scale
]
...
...
@@ -255,6 +247,7 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant.
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
vllm/compilation/inductor_pass.py
View file @
db74d604
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
hashlib
import
inspect
import
json
...
...
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import
torch
from
torch
import
fx
from
torch._subclasses.fake_tensor
import
(
FakeTensorMode
,
unset_fake_temporarily
)
from
vllm.utils
import
is_torch_equal_or_newer
...
...
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def
uuid
(
self
)
->
Any
:
return
self
.
_uuid
def
enable_fake_mode
(
fn
:
Callable
[...,
Any
])
->
Callable
[...,
Any
]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@
functools
.
wraps
(
fn
)
def
fn_new
(
*
args
,
**
kwargs
)
->
Any
:
with
torch
.
_guards
.
tracing
(
None
),
unset_fake_temporarily
(),
FakeTensorMode
():
result
=
fn
(
*
args
,
**
kwargs
)
return
result
return
fn_new
vllm/compilation/sequence_parallelism.py
View file @
db74d604
...
...
@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.inductor_pass
import
enable_fake_mode
from
.vllm_inductor_pass
import
VllmInductorPass
logger
=
init_logger
(
__name__
)
...
...
@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance.
"""
@
enable_fake_mode
def
__init__
(
self
,
config
:
VllmConfig
):
super
().
__init__
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment