Commit 64948a2e authored by wangshankun's avatar wangshankun
Browse files

dist group统一从runner入口

parent 826e9b03
...@@ -7,9 +7,9 @@ from lightx2v.models.networks.wan.infer.utils import pad_freqs ...@@ -7,9 +7,9 @@ from lightx2v.models.networks.wan.infer.utils import pad_freqs
class WanTransformerDistInfer(WanTransformerInfer): class WanTransformerDistInfer(WanTransformerInfer):
def __init__(self, config): def __init__(self, config, seq_p_group=None):
super().__init__(config) super().__init__(config)
self.seq_p_group = self.config["device_mesh"].get_group(mesh_dim="seq_p") self.seq_p_group = seq_p_group
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
x, embed0 = self.dist_pre_process(x, embed0) x, embed0 = self.dist_pre_process(x, embed0)
......
...@@ -83,7 +83,7 @@ class WanModel: ...@@ -83,7 +83,7 @@ class WanModel:
def _init_infer_class(self): def _init_infer_class(self):
self.pre_infer_class = WanPreInfer self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer self.post_infer_class = WanPostInfer
if self.config["seq_parallel"]: if self.seq_p_group is not None:
self.transformer_infer_class = WanTransformerDistInfer self.transformer_infer_class = WanTransformerDistInfer
else: else:
if self.config["feature_caching"] == "NoCaching": if self.config["feature_caching"] == "NoCaching":
...@@ -255,7 +255,12 @@ class WanModel: ...@@ -255,7 +255,12 @@ class WanModel:
def _init_infer(self): def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config) self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config) self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
if self.seq_p_group is not None:
self.transformer_infer = self.transformer_infer_class(self.config, self.seq_p_group)
else:
self.transformer_infer = self.transformer_infer_class(self.config)
if self.config["cfg_parallel"]: if self.config["cfg_parallel"]:
self.infer_func = self.infer_with_cfg_parallel self.infer_func = self.infer_with_cfg_parallel
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment