Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d1432712
Unverified
Commit
d1432712
authored
Dec 08, 2025
by
Jiangyun Zhu
Committed by
GitHub
Dec 08, 2025
Browse files
[Bugfix] fix fuse_allreduce_rms when tp =1 (#30178)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
c6df05eb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
0 deletions
+7
-0
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+7
-0
No files found.
vllm/compilation/collective_fusion.py
View file @
d1432712
...
@@ -1076,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1076,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self
.
disabled
=
True
self
.
disabled
=
True
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
tp_size
<=
1
:
if
self
.
tp_size
<=
1
:
logger
.
warning_once
(
"AllReduce fusion pass is disabled for tp_size <= 1."
)
return
return
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"all_reduce_fusion_pass"
pass_name
=
"all_reduce_fusion_pass"
)
)
if
config
.
model_config
is
None
:
if
config
.
model_config
is
None
:
logger
.
warning_once
(
"AllReduce fusion pass is disabled for missing model_config."
)
return
return
self
.
hidden_dim
=
config
.
model_config
.
get_hidden_size
()
self
.
hidden_dim
=
config
.
model_config
.
get_hidden_size
()
self
.
group
=
get_tp_group
().
device_group
self
.
group
=
get_tp_group
().
device_group
...
@@ -1188,6 +1192,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
...
@@ -1188,6 +1192,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self
.
disabled
=
False
self
.
disabled
=
False
def
is_applicable_for_range
(
self
,
compile_range
:
Range
)
->
bool
:
def
is_applicable_for_range
(
self
,
compile_range
:
Range
)
->
bool
:
if
self
.
disabled
:
logger
.
warning_once
(
"AllReduce fusion pass is disabled."
)
return
False
return
compile_range
.
end
<=
self
.
max_token_num
return
compile_range
.
end
<=
self
.
max_token_num
@
VllmInductorPass
.
time_and_log
@
VllmInductorPass
.
time_and_log
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment