Commit 87bbed1c authored by helloyongyang's avatar helloyongyang
Browse files

fix ci

parent 145161cd
......@@ -292,15 +292,10 @@ class VideoGenerator:
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0
prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(
prev_latents,
size=(height, width),
mode='bilinear',
align_corners=False
)
prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
return {"prev_latents": prev_latents, "prev_mask": prev_mask}
......@@ -358,15 +353,10 @@ class VideoGenerator:
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0
prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(
prev_latents,
size=(height, width),
mode='bilinear',
align_corners=False
)
prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
# Always set previmg_encoder_output
inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment