Unverified Commit 843e3f93 authored by Ömer Karışman's avatar Ömer Karışman Committed by GitHub
Browse files

wan2.2 i2v FirstBlockCache fix (#12013)

* enable caching for WanImageToVideoPipeline

* ruff format
parent d8854b8d
...@@ -750,6 +750,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -750,6 +750,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
timestep = t.expand(latents.shape[0]) timestep = t.expand(latents.shape[0])
with current_model.cache_context("cond"):
noise_pred = current_model( noise_pred = current_model(
hidden_states=latent_model_input, hidden_states=latent_model_input,
timestep=timestep, timestep=timestep,
...@@ -760,6 +761,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): ...@@ -760,6 +761,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)[0] )[0]
if self.do_classifier_free_guidance: if self.do_classifier_free_guidance:
with current_model.cache_context("uncond"):
noise_uncond = current_model( noise_uncond = current_model(
hidden_states=latent_model_input, hidden_states=latent_model_input,
timestep=timestep, timestep=timestep,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment