Unverified Commit d8854b8d authored by YiYi Xu, committed by GitHub

[wan2.2] add 5b i2v (#12006)



* add 5b ti2v

* remove a copy

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Co-authored-by: Aryan <aryan@huggingface.co>

* Apply suggestions from code review

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 327e251b
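
For anyone trying out the new path, here is a minimal usage sketch of the 5B TI2V image-to-video flow this PR enables. The checkpoint id, resolution, frame count, and fps below are assumptions for illustration, not taken from this diff:

# Hedged sketch: run WanImageToVideoPipeline against a Wan2.2 5B TI2V checkpoint.
# The model id "Wan-AI/Wan2.2-TI2V-5B-Diffusers", the 704x1280 resolution, 121 frames,
# and fps=24 are assumptions for illustration; adjust to the actual released weights.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("first_frame.png")  # hypothetical local conditioning image
video = pipe(
    image=image,
    prompt="a cat walking through a garden, cinematic lighting",
    height=704,
    width=1280,
    num_frames=121,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "output.mp4", fps=24)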
@@ -370,7 +370,6 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
-    # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
     def prepare_latents(
         self,
         image: PipelineImageInput,
...
@@ -175,6 +175,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()
@@ -188,10 +189,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
 
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor
@@ -419,8 +420,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         else:
             latents = latents.to(device=device, dtype=dtype)
 
-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -453,6 +458,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std
 
+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
 
         if last_image is None:
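
As context for the hunk above, a minimal self-contained sketch of what first_frame_mask does (shapes are made up for the example; in the pipeline the condition is the single VAE-encoded first frame, broadcast over latent frames): the mask is 1 everywhere except latent frame 0, so the blend (1 - mask) * condition + mask * latents pins the first frame to the condition and leaves every other frame on the noisy trajectory.

# Toy illustration of the first-frame mask blend; the sizes below are arbitrary.
import torch

num_latent_frames, latent_height, latent_width = 4, 2, 2
latents = torch.randn(1, 16, num_latent_frames, latent_height, latent_width)  # noisy video latents
condition = torch.randn(1, 16, 1, latent_height, latent_width)                # encoded first frame (one latent frame)

first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0  # frame 0 is taken from the condition, not from noise

# condition broadcasts over the frame dimension; only frame 0 survives the (1 - mask) factor
blended = (1 - first_frame_mask) * condition + first_frame_mask * latents
assert torch.equal(blended[:, :, 0], condition[:, :, 0])  # first frame pinned to the condition
assert torch.equal(blended[:, :, 1:], latents[:, :, 1:])  # remaining frames untouched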
@@ -662,7 +674,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
...@@ -682,7 +694,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -682,7 +694,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
device, dtype=torch.float32 device, dtype=torch.float32
) )
latents, condition = self.prepare_latents(
latents_outputs = self.prepare_latents(
image, image,
batch_size * num_videos_per_prompt, batch_size * num_videos_per_prompt,
num_channels_latents, num_channels_latents,
...@@ -695,6 +708,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -695,6 +708,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents, latents,
last_image, last_image,
) )
if self.config.expand_timesteps:
latents, condition, first_frame_mask = latents_outputs
else:
latents, condition = latents_outputs
# 6. Denoising loop # 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +738,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2
 
-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])
 
                 noise_pred = current_model(
                     hidden_states=latent_model_input,
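
And a hedged sketch of the per-token timestep expansion in the hunk above, assuming the transformer turns every 2x2 spatial latent patch into one token (which is what the [:, ::2, ::2] stride implies): instead of one scalar timestep per sample, each token gets its own value, 0 for tokens from the conditioned first frame and t for all others.

# Toy illustration of expanding a scalar timestep into per-token timesteps.
# The latent sizes and the spatial patch size of 2 are assumptions for the example.
import torch

num_latent_frames, latent_height, latent_width = 3, 4, 4
t = torch.tensor(750.0)  # current scheduler timestep

first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

# one token per 2x2 latent patch: seq_len = F * (H // 2) * (W // 2) = 3 * 2 * 2 = 12
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
timestep = temp_ts.unsqueeze(0).expand(2, -1)  # pretend batch size of 2

print(timestep.shape)   # torch.Size([2, 12])
print(timestep[0, :4])  # tokens from the conditioned first frame -> tensor([0., 0., 0., 0.])
print(timestep[0, 4:])  # tokens from the remaining frames -> all 750.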
@@ -766,6 +792,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
         self._current_timestep = None
 
+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
...