Commit 64948a2e authored by wangshankun's avatar wangshankun
Browse files

dist group统一从runner入口

parent 826e9b03
......@@ -7,9 +7,9 @@ from lightx2v.models.networks.wan.infer.utils import pad_freqs
class WanTransformerDistInfer(WanTransformerInfer):
def __init__(self, config):
def __init__(self, config, seq_p_group=None):
super().__init__(config)
self.seq_p_group = self.config["device_mesh"].get_group(mesh_dim="seq_p")
self.seq_p_group = seq_p_group
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
x, embed0 = self.dist_pre_process(x, embed0)
......
......@@ -83,7 +83,7 @@ class WanModel:
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
if self.config["seq_parallel"]:
if self.seq_p_group is not None:
self.transformer_infer_class = WanTransformerDistInfer
else:
if self.config["feature_caching"] == "NoCaching":
......@@ -255,7 +255,12 @@ class WanModel:
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
if self.seq_p_group is not None:
self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
else:
self.transformer_infer = self.transformer_infer_class(self.config)
if self.config["cfg_parallel"]:
self.infer_func = self.infer_with_cfg_parallel
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment