Commit 145161cd authored by gaclove
Browse files

fix: add size mismatch handling for prev_latents in VideoGenerator to ensure consistent dimensions

parent edeae441
......@@ -292,6 +292,15 @@ class VideoGenerator:
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0
prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(
prev_latents,
size=(height, width),
mode='bilinear',
align_corners=False
)
return {"prev_latents": prev_latents, "prev_mask": prev_mask}
......@@ -349,6 +358,15 @@ class VideoGenerator:
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0
prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(
prev_latents,
size=(height, width),
mode='bilinear',
align_corners=False
)
# Always set previmg_encoder_output
inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment