Commit cc04b3fb authored by helloyongyang's avatar helloyongyang
Browse files

update seq_parallel func

parent 6ed39fbf
...@@ -388,7 +388,7 @@ class WanModel: ...@@ -388,7 +388,7 @@ class WanModel:
if padding_size > 0: if padding_size > 0:
x = F.pad(x, (0, 0, 0, padding_size)) x = F.pad(x, (0, 0, 0, padding_size))
x = torch.chunk(x, world_size, dim=0)[cur_rank] pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v": if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0 embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
...@@ -398,12 +398,8 @@ class WanModel: ...@@ -398,12 +398,8 @@ class WanModel:
embed = F.pad(embed, (0, 0, 0, padding_size)) embed = F.pad(embed, (0, 0, 0, padding_size))
embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))
embed = torch.chunk(embed, world_size, dim=0)[cur_rank] pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank]
embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank] pre_infer_out.embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank]
pre_infer_out.embed = embed
pre_infer_out.embed0 = embed0
pre_infer_out.x = x
return pre_infer_out return pre_infer_out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment