Unverified Commit 33c5d125 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] fix img2img pipeline for Playground (#7627)

* playground vae encoding should use std and mean of the vae.

* style.

* fix-copies.
parent aa1f00fd
...@@ -898,6 +898,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -898,6 +898,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
) )
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled # Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu") self.text_encoder_2.to("cpu")
...@@ -935,6 +941,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -935,6 +941,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self.vae.to(dtype) self.vae.to(dtype)
init_latents = init_latents.to(dtype) init_latents = init_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
latents_std = latents_std.to(device=self.device, dtype=dtype)
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
init_latents = self.vae.config.scaling_factor * init_latents init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
......
...@@ -665,6 +665,12 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -665,6 +665,12 @@ class StableDiffusionXLImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
) )
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled # Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu") self.text_encoder_2.to("cpu")
...@@ -702,6 +708,11 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -702,6 +708,11 @@ class StableDiffusionXLImg2ImgPipeline(
self.vae.to(dtype) self.vae.to(dtype)
init_latents = init_latents.to(dtype) init_latents = init_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
latents_std = latents_std.to(device=self.device, dtype=dtype)
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
init_latents = self.vae.config.scaling_factor * init_latents init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment