Commit 7d109a7c authored by helloyongyang

Simplify wan pre infer & Remove seq_len in WanScheduler

parent 68a807f1
@@ -27,7 +27,7 @@ class WanPreInfer:
         self.scheduler = scheduler

     def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
-        x = [self.scheduler.latents]
+        x = self.scheduler.latents
         if self.scheduler.flag_df:
             t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
@@ -39,7 +39,6 @@ class WanPreInfer:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
-        seq_len = self.scheduler.seq_len

         if self.task == "i2v":
             clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
@@ -50,16 +49,14 @@ class WanPreInfer:
                 idx_s = kv_start // frame_seq_length
                 idx_e = kv_end // frame_seq_length
                 image_encoder = image_encoder[:, idx_s:idx_e, :, :]
-            y = [image_encoder]
-            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+            y = image_encoder
+            x = torch.cat([x, y], dim=0)

         # embeddings
-        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
-        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
-        x = [u.flatten(2).transpose(1, 2) for u in x]
-        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
-        assert seq_lens.max() <= seq_len
-        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
+        x = weights.patch_embedding.apply(x.unsqueeze(0))
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        x = x.flatten(2).transpose(1, 2).contiguous()
+        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
...
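
Note: the change collapses the old per-sample list handling (x was a one-element list that got patch-embedded, flattened, and zero-padded up to the precomputed seq_len) into direct tensor ops on the single sample. Below is a minimal sketch of the simplified shape flow, not the repo's code: the channel counts, model width, and patch size are assumed values, a plain Conv3d stands in for weights.patch_embedding.apply, and .cuda() is omitted so it runs anywhere.

import torch

# hypothetical i2v latents and image-conditioning tensor, (C, F, H, W)
latents = torch.randn(16, 21, 60, 104)
image_encoder = torch.randn(20, 21, 60, 104)

x = torch.cat([latents, image_encoder], dim=0)   # was: [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

# stand-in for weights.patch_embedding.apply (assumption: Conv3d with patch-sized kernel and stride)
patch_size = (1, 2, 2)
patch_embedding = torch.nn.Conv3d(36, 1536, kernel_size=patch_size, stride=patch_size)

x = patch_embedding(x.unsqueeze(0))                                     # (1, D, F, H/2, W/2)
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)   # (1, 3), no torch.stack over a list
x = x.flatten(2).transpose(1, 2).contiguous()                           # (1, L, D) token sequence
seq_lens = torch.tensor(x.size(1), dtype=torch.long).unsqueeze(0)       # (1,), taken from the tensor itself

print(grid_sizes.tolist(), seq_lens.tolist())                           # [[21, 30, 52]] [32760] for these shapes

Because seq_lens is now read from x.size(1) directly, the old assert against the scheduler's seq_len and the zero-padding to that length have nothing left to check.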
@@ -27,11 +27,6 @@ class WanScheduler(BaseScheduler):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)

-        if self.config.task in ["t2v"]:
-            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
-        elif self.config.task in ["i2v"]:
-            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-
         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
...
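
Note: with seq_lens read straight off the patched tensor in WanPreInfer, the analytic seq_len cached on the scheduler is redundant, which is why both the t2v and i2v branches can be dropped. A quick sanity check with assumed values (standing in for self.config.*, not the repo's actual config) shows the removed t2v formula and the token count implied by the latent and patch shapes agree:

import math

# assumed values, standing in for self.config.target_shape and self.config.patch_size
target_shape = (16, 21, 60, 104)   # (C, F, H, W) of the latent video
patch_size = (1, 2, 2)

# formula removed from WanScheduler (t2v branch)
seq_len_removed = math.ceil(
    (target_shape[2] * target_shape[3]) / (patch_size[1] * patch_size[2]) * target_shape[1]
)

# what WanPreInfer now computes implicitly: frames * (H / pH) * (W / pW) tokens
f = target_shape[1] // patch_size[0]
h = target_shape[2] // patch_size[1]
w = target_shape[3] // patch_size[2]
seq_len_from_shape = f * h * w

print(seq_len_removed, seq_len_from_shape)   # 32760 32760 for these values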