Unverified commit 843e3f93, authored by Ömer Karışman, committed by GitHub

wan2.2 i2v FirstBlockCache fix (#12013)

* enable caching for WanImageToVideoPipeline

* ruff format
parent d8854b8d
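
Not part of the commit, but for orientation: a minimal sketch of how FirstBlockCache might be enabled on this pipeline once the fix is in place. The checkpoint id, the threshold value, and the apply_first_block_cache / FirstBlockCacheConfig import path are assumptions drawn from the diffusers cache utilities, not from this diff.

    import torch
    from diffusers import WanImageToVideoPipeline
    from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache  # assumed import path

    # Hypothetical checkpoint id; substitute the Wan 2.2 I2V checkpoint you actually use.
    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    # Apply FirstBlockCache to the denoising transformer(s). Wan 2.2 routes steps through
    # two experts, so cover both when present; threshold=0.2 is an illustrative value.
    apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
    if getattr(pipe, "transformer_2", None) is not None:
        apply_first_block_cache(pipe.transformer_2, FirstBlockCacheConfig(threshold=0.2))

    # Subsequent pipe(image=..., prompt=...) calls then run through the
    # cache_context("cond") / cache_context("uncond") blocks added below.
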
@@ -750,25 +750,27 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                noise_pred = current_model(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_hidden_states_image=image_embeds,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_uncond = current_model(
-                        hidden_states=latent_model_input,
-                        timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        encoder_hidden_states_image=image_embeds,
-                        attention_kwargs=attention_kwargs,
-                        return_dict=False,
-                    )[0]
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
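
Why two cache contexts (my reading of the change, not text from the commit): FirstBlockCache decides whether to skip the remaining transformer blocks by comparing the first block's output against the output cached from the previous forward pass. Under classifier-free guidance the conditional and unconditional passes alternate within each denoising step, so a single shared cache would compare one branch against the other and store mixed state. Wrapping the calls in current_model.cache_context("cond") and cache_context("uncond") gives each branch its own cache slot, matching the pattern the cache-enabled text-to-video Wan pipeline already uses.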