Unverified Commit 1e216be8 authored by Suraj Patil, committed by GitHub

make scaling factor a config arg of vae/vqvae (#1860)



* make scaling factor config arg of vae

* fix

* make flake happy

* fix ldm

* fix upscaler

* quality

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* solve conflicts, address some comments

* examples

* examples min version

* doc

* fix type

* typo

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* remove duplicate line

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 915a5636
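The diff below swaps every hardcoded `0.18215` for the value stored on the model config. As a rough sketch of the encode side after this change (the checkpoint name, tensor shapes, and the surrounding code are illustrative, not part of this commit):

```python
import torch
from diffusers import AutoencoderKL

# Load a pretrained VAE; "CompVis/stable-diffusion-v1-4" is just an illustrative checkpoint.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.requires_grad_(False)

pixel_values = torch.randn(1, 3, 512, 512)  # stand-in for a training batch

with torch.no_grad():
    # Encode images to latents, then scale them to roughly unit variance.
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents * vae.config.scaling_factor  # previously: latents * 0.18215
```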
@@ -150,7 +150,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-        sample = 1 / 0.18215 * sample
+        sample = 1 / self.vae.config.scaling_factor * sample
         image = self.vae.decode(sample).sample
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -336,7 +336,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
......
@@ -803,7 +803,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -533,7 +533,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise_rng, timestep_rng = jax.random.split(sample_rng)
......
@@ -853,7 +853,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -607,7 +607,7 @@ def main(args):
                 optimizer.zero_grad()
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -33,7 +33,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)
@@ -699,13 +699,13 @@ def main():
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Convert masked images to latent space
                 masked_latents = vae.encode(
                     batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                 ).latent_dist.sample()
-                masked_latents = masked_latents * 0.18215
+                masked_latents = masked_latents * vae.config.scaling_factor
                 masks = batch["masks"]
                 # resize the mask to latents shape as we concatenate the mask to the latents
......
@@ -51,7 +51,7 @@ else:
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)
@@ -555,7 +555,7 @@ def main():
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn(latents.shape).to(latents.device)
......
@@ -31,7 +31,7 @@ from transformers import AutoTokenizer, PretrainedConfig
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)
@@ -788,7 +788,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -636,7 +636,7 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -438,7 +438,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise_rng, timestep_rng = jax.random.split(sample_rng)
......
@@ -689,7 +689,7 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -711,7 +711,7 @@ def main():
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
......
@@ -525,7 +525,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             noise_rng, timestep_rng = jax.random.split(sample_rng)
             noise = jax.random.normal(noise_rng, latents.shape)
......
@@ -54,8 +54,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(64,)`): Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
         sample_size (`int`, *optional*, defaults to `32`): TODO
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
     @register_to_config
@@ -71,6 +78,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         latent_channels: int = 4,
         norm_num_groups: int = 32,
         sample_size: int = 32,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
......
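Because `scaling_factor` is registered via `@register_to_config`, it is now part of the serialized model config rather than a constant repeated across pipelines and training scripts. A minimal sketch of what that looks like, assuming default argument values; a checkpoint trained with a different factor would simply carry its own value in its config:

```python
from diffusers import AutoencoderKL

# Instantiate a VAE with an explicit scaling factor; the other arguments keep their defaults.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    scaling_factor=0.18215,
)

# The value is exposed through the config, which is what the pipelines now read.
print(vae.config.scaling_factor)  # 0.18215

# Older configs that lack the key should fall back to the signature default (0.18215),
# so existing checkpoints keep working unchanged.
```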
@@ -752,8 +752,15 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             Latent space channels
         norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             Norm num group
-        sample_size (:obj:`int`, *optional*, defaults to `32`):
+        sample_size (:obj:`int`, *optional*, defaults to 32):
             Sample input size
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             parameters `dtype`
     """
@@ -767,6 +774,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     latent_channels: int = 4
     norm_num_groups: int = 32
     sample_size: int = 32
+    scaling_factor: float = 0.18215
     dtype: jnp.dtype = jnp.float32
     def setup(self):
......
@@ -57,6 +57,13 @@ class VQModel(ModelMixin, ConfigMixin):
         sample_size (`int`, *optional*, defaults to `32`): TODO
         num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
         vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
     @register_to_config
@@ -74,6 +81,7 @@ class VQModel(ModelMixin, ConfigMixin):
         num_vq_embeddings: int = 256,
         norm_num_groups: int = 32,
         vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
......
@@ -369,7 +369,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
......
@@ -391,7 +391,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         return image, has_nsfw_concept
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -490,7 +490,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
......
@@ -153,7 +153,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                 input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
                     generator=generator
                 )[0]
-                input_images = 0.18215 * input_images
+                input_images = self.vqvae.config.scaling_factor * input_images
             if start_step > 0:
                 images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
@@ -195,7 +195,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         if self.vqvae is not None:
             # 0.18215 was scaling factor used in training to ensure unit variance
-            images = 1 / 0.18215 * images
+            images = 1 / self.vqvae.config.scaling_factor * images
             images = self.vqvae.decode(images)["sample"]
         images = (images / 2 + 0.5).clamp(0, 1)
......
@@ -182,7 +182,7 @@ class DiTPipeline(DiffusionPipeline):
         else:
             latents = latent_model_input
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         samples = self.vae.decode(latents).sample
         samples = (samples / 2 + 0.5).clamp(0, 1)
......
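On the decode side, the pipelines now invert the same config value before calling the VAE decoder. A minimal stand-alone sketch of the `decode_latents`-style pattern shown in the hunks above (the checkpoint name and latent shape are illustrative):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

latents = torch.randn(1, 4, 64, 64)  # stand-in for denoised latents from a scheduler loop

with torch.no_grad():
    # Undo the training-time scaling, then decode back to image space.
    latents = 1 / vae.config.scaling_factor * latents  # previously: 1 / 0.18215 * latents
    image = vae.decode(latents).sample

# Map from [-1, 1] to [0, 1] as the pipelines do.
image = (image / 2 + 0.5).clamp(0, 1)
```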