Unverified Commit 1e216be8 authored by Suraj Patil, committed by GitHub

make scaling factor a config arg of vae/vqvae (#1860)



* make scaling factor a config arg of vae

* fix

* make flake happy

* fix ldm

* fix upscaler

* quality

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* solve conflicts, address some comments

* examples

* examples min version

* doc

* fix type

* typo

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* remove duplicate line

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 915a5636
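
The diff below replaces every hardcoded latent scaling constant (0.18215 for Stable Diffusion VAEs, 0.08333 for the x4 upscaler) with a `scaling_factor` entry on the model config, so pipelines read it via `self.vae.config.scaling_factor` instead of a magic number. A minimal sketch of the new usage, assuming a checkpoint whose `vae/config.json` carries the new entry (the checkpoint id is only an example):

from diffusers import AutoencoderKL

# any Stable Diffusion v1-style checkpoint works the same way
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

# the constant formerly hardcoded in every pipeline is now a config attribute
print(vae.config.scaling_factor)  # 0.18215 for SD v1/v2 checkpoints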
@@ -182,7 +182,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -257,7 +257,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
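
This `decode_latents` hunk recurs nearly verbatim across the pipelines in this diff; only the source of the constant changes. A standalone sketch of the decode path (the helper name is ours, not part of the library API):

import torch

def decode_latents(vae, latents):
    # undo the scaling applied at encode time before running the decoder
    latents = 1 / vae.config.scaling_factor * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    # map decoder output from [-1, 1] to [0, 1] for image post-processing
    return (image / 2 + 0.5).clamp(0, 1)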
@@ -328,7 +328,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
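
The encode side is the mirror image: latents sampled from the VAE posterior are multiplied by the same factor before being handed to the UNet. A hedged sketch (helper name is ours; `image` is a preprocessed tensor in [-1, 1]):

def encode_image(vae, image, generator=None):
    # sample latents from the posterior, then scale into the range the UNet expects
    latents = vae.encode(image).latent_dist.sample(generator=generator)
    return vae.config.scaling_factor * latents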
@@ -474,7 +474,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -509,7 +509,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -267,7 +267,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
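
The Flax pipelines make the same swap; the only difference is that module parameters are passed explicitly through `apply`, since Flax modules are stateless. A sketch assuming `vae` is a `FlaxAutoencoderKL` and `vae_params` its parameter dict (function name is ours):

def flax_decode_latents(vae, vae_params, latents):
    # identical config lookup; the params dict travels alongside the module
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.apply({"params": vae_params}, latents, method=vae.decode).sample
    # NHWC output in [0, 1] for downstream numpy/PIL conversion
    return (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)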
@@ -224,7 +224,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         # Create init_latents
         init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
         init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         def loop_body(step, args):
             latents, scheduler_state = args
@@ -272,7 +272,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
             latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -259,7 +259,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
             {"params": params["vae"]}, masked_image, method=self.vae.encode
         ).latent_dist
         masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
         del mask_prng_seed
 
         mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")
@@ -327,7 +327,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
         )
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -366,7 +366,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
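
Why the factor matters: it normalizes VAE latents to roughly unit variance so the diffusion UNet sees inputs on the scale it was trained for. Encode multiplies, decode divides, and the two must agree, which is exactly why centralizing the value in the config is safer than repeating a literal. A round-trip sketch with random data (checkpoint id and shapes are only illustrative):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    decoded = vae.decode(latents / vae.config.scaling_factor).sample  # back to image space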
@@ -310,7 +310,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -413,7 +413,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -195,7 +195,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -400,7 +400,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -500,7 +500,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -466,7 +466,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -561,7 +561,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -367,7 +367,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -450,7 +450,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         # Expand init_latents for batch_size and num_images_per_prompt
         init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
@@ -588,7 +588,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -313,7 +313,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -89,6 +89,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
     ):
         super().__init__()
 
+        # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
+        is_vae_scaling_factor_set_to_0_08333 = (
+            hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
+        )
+        if not is_vae_scaling_factor_set_to_0_08333:
+            deprecation_message = (
+                "The configuration file of the vae does not contain `scaling_factor` or it is set to"
+                f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
+                " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333"
+                " Please make sure to update the config accordingly, as not doing so might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull Request for the `vae/config.json` file"
+            )
+            deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
+            vae.register_to_config(scaling_factor=0.08333)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
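
The hunk above is a backward-compatibility shim: older `vae/config.json` files lack `scaling_factor`, so `__init__` backfills it with a deprecation warning rather than failing. Distilled to its core as a standalone sketch (function name is ours; we use the stdlib `warnings` in place of diffusers' internal `deprecate` helper):

import warnings

def ensure_scaling_factor(vae, expected=0.08333):
    # backfill a missing or wrong config entry, warning instead of raising
    if getattr(vae.config, "scaling_factor", None) != expected:
        warnings.warn("vae config lacks the expected scaling_factor; patching it", FutureWarning)
        vae.register_to_config(scaling_factor=expected)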
@@ -292,9 +308,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.08333 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
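
In practice the shim is invisible to users: loading the upscaler just works, and the patched value is visible on the config. A hypothetical check (the checkpoint id comes from the deprecation message above):

from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler")
print(pipe.vae.config.scaling_factor)  # 0.08333, patched if the checkpoint config predates this change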
@@ -364,7 +364,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -330,7 +330,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -190,7 +190,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -247,7 +247,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16