"vscode:/vscode.git/clone" did not exist on "b73d3d6ed96f8459776b23d4fba1af746e643fa4"
Unverified Commit f7665abb authored by Yang Yong (雍洋), committed by GitHub

Fix rope for parallel (#530)

parent 2479e81d
@@ -2,6 +2,8 @@
 from typing import List, Optional, Union
 import numpy as np
 import torch
+import torch.distributed as dist
+from torch.nn import functional as F
 from lightx2v.models.schedulers.scheduler import BaseScheduler
 from lightx2v.utils.utils import masks_like
@@ -14,6 +16,10 @@ class WanScheduler(BaseScheduler):
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
+        if self.config["seq_parallel"]:
+            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None
         self.patch_size = (1, 2, 2)
         self.shift = 1
         self.num_train_timesteps = 1000
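The `get_group(mesh_dim="seq_p")` call above is the standard PyTorch `DeviceMesh` API; the commit assumes the caller stores a mesh with a `"seq_p"` dimension under `config["device_mesh"]`. A minimal sketch of that wiring, launched under `torchrun` (the 2x4 mesh shape and everything except the `device_mesh`/`seq_p` names are illustrative assumptions, not part of this commit):

# Sketch: producing the "seq_p" process group the scheduler reads from its config.
# Assumed launch: torchrun --nproc-per-node=8 this_script.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs split into 2 data-parallel replicas x 4 sequence-parallel ranks (illustrative).
device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "seq_p"))
config = {"seq_parallel": True, "device_mesh": device_mesh}

# Mirrors the __init__ branch added in this commit.
seq_p_group = config["device_mesh"].get_group(mesh_dim="seq_p")
print(dist.get_rank(seq_p_group), dist.get_world_size(seq_p_group))  # rank in 0..3, size 4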
@@ -87,8 +93,24 @@ class WanScheduler(BaseScheduler):
             cos_half = cos_sin.real.contiguous()
             sin_half = cos_sin.imag.contiguous()
             cos_sin = torch.cat([cos_half, sin_half], dim=-1)
+            if self.seq_p_group is not None:
+                world_size = dist.get_world_size(self.seq_p_group)
+                cur_rank = dist.get_rank(self.seq_p_group)
+                seqlen = cos_sin.shape[0]
+                padding_size = (world_size - (seqlen % world_size)) % world_size
+                if padding_size > 0:
+                    cos_sin = F.pad(cos_sin, (0, 0, 0, padding_size))
+                cos_sin = torch.chunk(cos_sin, world_size, dim=0)[cur_rank]
         else:
             cos_sin = cos_sin.reshape(seq_len, 1, -1)
+            if self.seq_p_group is not None:
+                world_size = dist.get_world_size(self.seq_p_group)
+                cur_rank = dist.get_rank(self.seq_p_group)
+                seqlen = cos_sin.shape[0]
+                padding_size = (world_size - (seqlen % world_size)) % world_size
+                if padding_size > 0:
+                    cos_sin = F.pad(cos_sin, (0, 0, 0, 0, 0, padding_size))
+                cos_sin = torch.chunk(cos_sin, world_size, dim=0)[cur_rank]
         return cos_sin

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
...
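The padding exists because `torch.chunk(..., world_size, dim=0)` only yields equal-length shards when the sequence length divides evenly, and each sequence-parallel rank presumably needs a rope shard that lines up with its slice of the hidden states. Note the pad tuples: `F.pad` consumes (left, right) pairs starting from the last dimension, which is why the 2-D branch passes four values and the 3-D branch six to reach dim 0. A standalone sketch of the arithmetic, with `world_size` and `cur_rank` hard-coded in place of the process-group queries:

# Standalone sketch of the pad-then-chunk sharding added in this commit.
import torch
from torch.nn import functional as F

world_size, cur_rank = 4, 1                # stand-ins for dist.get_world_size/get_rank
cos_sin = torch.randn(10, 64)              # (seqlen, dim) -- the 2-D branch

seqlen = cos_sin.shape[0]
# Round seqlen up to a multiple of world_size: 10 -> pad 2 -> 12.
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
    # (0, 0, 0, p): last dim untouched, p rows appended to dim 0.
    cos_sin = F.pad(cos_sin, (0, 0, 0, padding_size))

# Divisibility now holds, so every rank gets an identical shard shape.
shard = torch.chunk(cos_sin, world_size, dim=0)[cur_rank]
print(shard.shape)                         # torch.Size([3, 64])

# The 3-D branch, shaped (seq_len, 1, dim), needs three pairs to reach dim 0:
cos_sin_3d = torch.randn(10, 1, 64)
cos_sin_3d = F.pad(cos_sin_3d, (0, 0, 0, 0, 0, 2))
print(cos_sin_3d.shape)                    # torch.Size([12, 1, 64])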