Commit 09d769cc authored by sandy's avatar sandy Committed by GitHub
Browse files

[Fix] audio_r2b_5b parallel (#252)

parent 2f0cfa56
...@@ -408,7 +408,7 @@ class WanModel: ...@@ -408,7 +408,7 @@ class WanModel:
pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank] pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v": if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0 embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
padding_size = (world_size - (embed.shape[0] % world_size)) % world_size padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment