Commit 7d109a7c authored by helloyongyang

Simplify wan pre infer & Remove seq_len in WanScheduler

parent 68a807f1
@@ -27,7 +27,7 @@ class WanPreInfer:
         self.scheduler = scheduler
 
     def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
-        x = [self.scheduler.latents]
+        x = self.scheduler.latents
         if self.scheduler.flag_df:
             t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
@@ -39,7 +39,6 @@
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
-        seq_len = self.scheduler.seq_len
 
         if self.task == "i2v":
             clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
@@ -50,16 +49,14 @@
             idx_s = kv_start // frame_seq_length
             idx_e = kv_end // frame_seq_length
             image_encoder = image_encoder[:, idx_s:idx_e, :, :]
-            y = [image_encoder]
-            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+            y = image_encoder
+            x = torch.cat([x, y], dim=0)
 
         # embeddings
-        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
-        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
-        x = [u.flatten(2).transpose(1, 2) for u in x]
-        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
-        assert seq_lens.max() <= seq_len
-        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
+        x = weights.patch_embedding.apply(x.unsqueeze(0))
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        x = x.flatten(2).transpose(1, 2).contiguous()
+        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
 
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
 
         if self.enable_dynamic_cfg:
......
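For reference, the pre-infer path now works on a single latent tensor instead of a list of per-sample tensors: the patch embedding, grid_sizes, and seq_lens all come from that one tensor, so there is no assert against a precomputed seq_len and no zero-padding step. Below is a minimal sketch of the new path with made-up shapes; a plain nn.Conv3d stands in for weights.patch_embedding.apply, the channel/patch/dim numbers are illustrative only, and the .cuda() call from the diff is omitted so the sketch runs on CPU.

# Minimal sketch of the simplified embedding path (hypothetical shapes;
# nn.Conv3d stands in for weights.patch_embedding.apply).
import torch
import torch.nn as nn

in_channels, dim = 16, 128                  # assumed latent channels / model dim
patch_size = (1, 2, 2)                      # assumed (t, h, w) patch size
patch_embedding = nn.Conv3d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(in_channels, 8, 32, 32)     # scheduler.latents: (C, F, H, W), no list wrapper

x = patch_embedding(x.unsqueeze(0))                                      # (1, dim, F', H', W')
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)    # (1, 3): latent grid sizes
x = x.flatten(2).transpose(1, 2).contiguous()                            # (1, F'*H'*W', dim) token sequence
seq_lens = torch.tensor(x.size(1), dtype=torch.long).unsqueeze(0)        # (1,): actual length, no padding

print(grid_sizes.tolist(), seq_lens.item(), tuple(x.shape))
# [[8, 16, 16]] 2048 (1, 2048, 128)

Compared with the removed list-based version, the sequence length is simply the number of patch tokens of the single tensor, which is what makes the scheduler-side seq_len computation unnecessary.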
@@ -27,11 +27,6 @@ class WanScheduler(BaseScheduler):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
 
-        if self.config.task in ["t2v"]:
-            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
-        elif self.config.task in ["i2v"]:
-            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-
         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
......
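With seq_len dropped from WanScheduler, the sequence length is implied by the patch-embedded latents themselves. As a sanity check under the same assumed shapes as the sketch above (nn.Conv3d again standing in for the real patch embedding), the removed t2v formula and the tensor-derived length agree:

import math
import torch
import torch.nn as nn

target_shape = (16, 8, 32, 32)              # assumed (C, F, H, W) latent shape
patch_size = (1, 2, 2)                      # assumed (t, h, w) patch size

# seq_len as the removed WanScheduler code computed it for t2v
seq_len_old = math.ceil(
    (target_shape[2] * target_shape[3])
    / (patch_size[1] * patch_size[2])
    * target_shape[1]
)

# the length that now falls out of the patch-embedded tensor in WanPreInfer
patch_embedding = nn.Conv3d(target_shape[0], 128, kernel_size=patch_size, stride=patch_size)
tokens = patch_embedding(torch.randn(1, *target_shape)).flatten(2).transpose(1, 2)
seq_len_new = tokens.size(1)

assert seq_len_old == seq_len_new == 2048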