Unverified Commit d1432712 authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[Bugfix] fix fuse_allreduce_rms when tp =1 (#30178)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent c6df05eb
......@@ -1076,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.disabled = True
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size <= 1:
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
return
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="all_reduce_fusion_pass"
)
if config.model_config is None:
logger.warning_once(
"AllReduce fusion pass is disabled for missing model_config."
)
return
self.hidden_dim = config.model_config.get_hidden_size()
self.group = get_tp_group().device_group
......@@ -1188,6 +1192,9 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.disabled = False
def is_applicable_for_range(self, compile_range: Range) -> bool:
if self.disabled:
logger.warning_once("AllReduce fusion pass is disabled.")
return False
return compile_range.end <= self.max_token_num
@VllmInductorPass.time_and_log
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment