"...text-generation-inference.git" did not exist on "31d76e238df7654157ab1e372b7d57ef859daaa7"
Unverified commit 1e216be8, authored by Suraj Patil and committed by GitHub

make scaling factor a config arg of vae/vqvae (#1860)



* make scaling factor config arg of vae

* fix

* make flake happy

* fix ldm

* fix upscaler

* quality

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* solve conflicts, address some comments

* examples

* examples min version

* doc

* fix type

* typo

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* remove duplicate line

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 915a5636
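The change in one picture: pipelines and training scripts stop hard-coding the Stable Diffusion latent scale 0.18215 and read it from the VAE/VQ-VAE config instead. A minimal sketch of the before/after pattern (the checkpoint name is only illustrative; any diffusers VAE carries the new config entry, defaulting to 0.18215):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; real code would use whatever VAE the pipeline/script loads.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

image = torch.randn(1, 3, 256, 256)  # dummy input standing in for a real image batch
latents = vae.encode(image).latent_dist.sample()

# before: latents = latents * 0.18215
latents = latents * vae.config.scaling_factor  # scale latents to roughly unit variance

# before: decoded = vae.decode(1 / 0.18215 * latents).sample
decoded = vae.decode(1 / vae.config.scaling_factor * latents).sample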
@@ -150,7 +150,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-        sample = 1 / 0.18215 * sample
+        sample = 1 / self.vae.config.scaling_factor * sample
         image = self.vae.decode(sample).sample
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -336,7 +336,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
......
@@ -803,7 +803,7 @@ def main(args):
         with accelerator.accumulate(unet):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
......
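On the training side, every script now multiplies the encoded latents by vae.config.scaling_factor instead of the literal 0.18215; checkpoints saved before this change simply fall back to the constructor default of 0.18215, so behavior is unchanged. A rough, self-contained sketch of that training step (checkpoint name and batch are stand-ins for what the real scripts build from their arguments):

import torch
from diffusers import AutoencoderKL, DDPMScheduler

# Stand-ins for the objects the training scripts construct from their CLI args.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
pixel_values = torch.randn(4, 3, 256, 256)  # stand-in for batch["pixel_values"]

# Convert images to latent space, scaled exactly as in the updated scripts
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor

# Sample noise and timesteps, then produce the noisy latents the UNet is trained on
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)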
@@ -533,7 +533,7 @@ def main():
         latents = vae_outputs.latent_dist.sample(sample_rng)
         # (NHWC) -> (NCHW)
         latents = jnp.transpose(latents, (0, 3, 1, 2))
-        latents = latents * 0.18215
+        latents = latents * vae.config.scaling_factor
         # Sample noise that we'll add to the latents
         noise_rng, timestep_rng = jax.random.split(sample_rng)
......
@@ -853,7 +853,7 @@ def main(args):
         with accelerator.accumulate(unet):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
......
@@ -607,7 +607,7 @@ def main(args):
         optimizer.zero_grad()
         latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-        latents = latents * 0.18215
+        latents = latents * vae.config.scaling_factor
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
......
@@ -33,7 +33,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)
@@ -699,13 +699,13 @@ def main():
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Convert masked images to latent space
             masked_latents = vae.encode(
                 batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
             ).latent_dist.sample()
-            masked_latents = masked_latents * 0.18215
+            masked_latents = masked_latents * vae.config.scaling_factor
             masks = batch["masks"]
             # resize the mask to latents shape as we concatenate the mask to the latents
......
@@ -51,7 +51,7 @@ else:
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)
@@ -555,7 +555,7 @@ def main():
         with accelerator.accumulate(text_encoder):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn(latents.shape).to(latents.device)
......
@@ -31,7 +31,7 @@ from transformers import AutoTokenizer, PretrainedConfig
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 logger = get_logger(__name__)
@@ -788,7 +788,7 @@ def main(args):
         with accelerator.accumulate(unet):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
......
@@ -636,7 +636,7 @@ def main():
         with accelerator.accumulate(unet):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
......
@@ -438,7 +438,7 @@ def main():
         latents = vae_outputs.latent_dist.sample(sample_rng)
         # (NHWC) -> (NCHW)
         latents = jnp.transpose(latents, (0, 3, 1, 2))
-        latents = latents * 0.18215
+        latents = latents * vae.config.scaling_factor
         # Sample noise that we'll add to the latents
         noise_rng, timestep_rng = jax.random.split(sample_rng)
......
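The Flax scripts follow the same pattern, with the extra NHWC-to-NCHW transpose shown in the hunks above. A rough sketch of that encode step, under the assumption that the Flax from_pretrained call returns the module together with its parameters and that the encoder accepts NCHW input (as the example scripts do):

import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

# Assumed for illustration: from_pretrained returns (module, params); from_pt converts PyTorch weights.
vae, vae_params = FlaxAutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", from_pt=True)

rng = jax.random.PRNGKey(0)
pixel_values = jnp.zeros((4, 3, 256, 256))  # stand-in for batch["pixel_values"]

# Encode, sample from the latent distribution, and scale with the config value
vae_outputs = vae.apply({"params": vae_params}, pixel_values, deterministic=True, method=vae.encode)
latents = vae_outputs.latent_dist.sample(rng)
latents = jnp.transpose(latents, (0, 3, 1, 2))  # (NHWC) -> (NCHW)
latents = latents * vae.config.scaling_factor   # config-driven scale, as in the diff above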
@@ -689,7 +689,7 @@ def main():
         with accelerator.accumulate(unet):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
......
@@ -711,7 +711,7 @@ def main():
         with accelerator.accumulate(text_encoder):
             # Convert images to latent space
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
......
@@ -525,7 +525,7 @@ def main():
         latents = vae_outputs.latent_dist.sample(sample_rng)
         # (NHWC) -> (NCHW)
         latents = jnp.transpose(latents, (0, 3, 1, 2))
-        latents = latents * 0.18215
+        latents = latents * vae.config.scaling_factor
         noise_rng, timestep_rng = jax.random.split(sample_rng)
         noise = jax.random.normal(noise_rng, latents.shape)
......
@@ -54,8 +54,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(64,)`): Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
         sample_size (`int`, *optional*, defaults to `32`): TODO
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

     @register_to_config
@@ -71,6 +78,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         latent_channels: int = 4,
         norm_num_groups: int = 32,
         sample_size: int = 32,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
......
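Because scaling_factor goes through @register_to_config, it is stored in the model config and round-trips through save_pretrained / from_pretrained, so a VAE whose latents were normalized with a different factor can simply carry its own value. A small sketch (the 0.2 here is a made-up value purely for illustration):

from diffusers import AutoencoderKL

# Hypothetical custom factor; a retrained VAE would store whatever scale its own latents need.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    sample_size=32,
    scaling_factor=0.2,
)

print(vae.config.scaling_factor)  # 0.2, persisted in config.json when the model is saved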
@@ -752,8 +752,15 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             Latent space channels
         norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             Norm num group
-        sample_size (:obj:`int`, *optional*, defaults to `32`):
+        sample_size (:obj:`int`, *optional*, defaults to 32):
             Sample input size
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             parameters `dtype`
     """
@@ -767,6 +774,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     latent_channels: int = 4
     norm_num_groups: int = 32
     sample_size: int = 32
+    scaling_factor: float = 0.18215
     dtype: jnp.dtype = jnp.float32

     def setup(self):
......
@@ -57,6 +57,13 @@ class VQModel(ModelMixin, ConfigMixin):
         sample_size (`int`, *optional*, defaults to `32`): TODO
         num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
         vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

     @register_to_config
@@ -74,6 +81,7 @@ class VQModel(ModelMixin, ConfigMixin):
         num_vq_embeddings: int = 256,
         norm_num_groups: int = 32,
         vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
......
@@ -369,7 +369,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept

     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
......
@@ -391,7 +391,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         return image, has_nsfw_concept

     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -490,7 +490,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
......
@@ -153,7 +153,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
                 generator=generator
             )[0]
-            input_images = 0.18215 * input_images
+            input_images = self.vqvae.config.scaling_factor * input_images
         if start_step > 0:
             images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
@@ -195,7 +195,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         if self.vqvae is not None:
             # 0.18215 was scaling factor used in training to ensure unit variance
-            images = 1 / 0.18215 * images
+            images = 1 / self.vqvae.config.scaling_factor * images
             images = self.vqvae.decode(images)["sample"]
             images = (images / 2 + 0.5).clamp(0, 1)
......
@@ -182,7 +182,7 @@ class DiTPipeline(DiffusionPipeline):
         else:
             latents = latent_model_input
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         samples = self.vae.decode(latents).sample
         samples = (samples / 2 + 0.5).clamp(0, 1)
......
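All of the pipeline changes above reduce to the same decode pattern; a minimal standalone sketch of that helper, assuming only a diffusers VAE-like object with encode/decode and the new config entry:

def decode_latents(vae, latents):
    # Undo the training-time scaling before handing latents to the VAE decoder.
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # cast to float32 and move channels last for conversion to numpy / PIL
    return image.cpu().permute(0, 2, 3, 1).float().numpy()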