Unverified Commit 8b451eb6 authored by Patrick von Platen, committed by GitHub

Fix config prints and save, load of pipelines (#2849)

* [Config] Fix config prints and save, load

* Only use potential nn.Modules for dtype and device

* Correct vae image processor

* make sure in_channels is not accessed directly

* make sure in channels is only accessed via config

* Make sure schedulers only access config attributes

* Make sure to access config in SAG

* Fix vae processor and make style

* add tests

* uP

* make style

* Fix more naming issues

* Final fix with vae config

* change more
parent 83691967
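
The common pattern across the hunks below is that model and scheduler hyperparameters (`in_channels`, `sample_size`, `num_train_timesteps`, ...) are now read from the object's frozen `config` rather than as direct attributes. A minimal sketch of the pattern, assuming a locally constructed `UNet2DModel` and `DDPMScheduler`; the concrete sizes are illustrative and not part of this commit:

    # Minimal sketch of the config-access pattern this commit enforces.
    # The sizes below are illustrative, not taken from the diff.
    import torch
    from diffusers import DDPMScheduler, UNet2DModel

    unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Hyperparameters are read from the frozen config, not as direct attributes:
    noise = torch.randn(
        (1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)
    )
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,)).long()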
@@ -344,7 +344,7 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
 ...         # Sample a random timestep for each image
 ...         timesteps = torch.randint(
-...             0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
+...             0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
 ...         ).long()
 ...         # Add noise to the clean images according to the noise magnitude at each timestep
@@ -62,7 +62,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
     def __call__(self):
         image = torch.randn(
-            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
         )
         timestep = 1
@@ -108,7 +108,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
     def __call__(self):
         image = torch.randn(
-            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
         )
         timestep = 1
@@ -89,7 +89,9 @@ class MyPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
         # Sample gaussian noise to begin loop
-        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+        image = torch.randn(
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
+        )
         image = image.to(self.device)
@@ -238,7 +238,7 @@ class BitDiffusion(DiffusionPipeline):
         **kwargs,
     ) -> Union[Tuple, ImagePipelineOutput]:
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, height, width),
+            (batch_size, self.unet.config.in_channels, height, width),
             generator=generator,
         )
         latents = decimal_to_bits(latents) * self.bit_scale
@@ -254,7 +254,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -414,7 +414,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -513,7 +513,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -424,7 +424,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if self.device.type == "mps":
             # randn does not exist on mps
@@ -320,7 +320,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -416,7 +416,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
     def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
         """Takes in random seed and returns corresponding noise vector"""
         return torch.randn(
-            (1, self.unet.in_channels, height // 8, width // 8),
+            (1, self.unet.config.in_channels, height // 8, width // 8),
             generator=torch.Generator(device=self.device).manual_seed(seed),
             device=self.device,
             dtype=dtype,
@@ -627,7 +627,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if image is None:
             shape = (
                 batch_size,
-                self.unet.in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
@@ -486,7 +486,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
         self.__init__additional__()
    def __init__additional__(self):
-        self.unet_in_channels = 4
+        self.unet.config.in_channels = 4
         self.vae_scale_factor = 8
    def _encode_prompt(
@@ -621,7 +621,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
         if image is None:
             shape = (
                 batch_size,
-                self.unet_in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
@@ -93,7 +93,7 @@ class MagicMixPipeline(DiffusionPipeline):
         torch.manual_seed(seed)
         noise = torch.randn(
-            (1, self.unet.in_channels, height // 8, width // 8),
+            (1, self.unet.config.in_channels, height // 8, width // 8),
         ).to(self.device)
         latents = self.scheduler.add_noise(
@@ -355,7 +355,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -433,7 +433,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         sigmas = sigmas.to(text_embeddings.dtype)
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -262,8 +262,8 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
-        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -190,7 +190,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -337,7 +337,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -794,7 +794,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()
         # Add noise to the latents according to the noise magnitude at each timestep
@@ -794,7 +794,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()
         # Add noise to the latents according to the noise magnitude at each timestep
@@ -641,7 +641,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
         timesteps = timesteps.long()
         # Add noise to the latents according to the noise magnitude at each timestep
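
For context, a minimal sketch of the training-loop step the three hunks above touch, assuming a `DDPMScheduler` and placeholder latents; both are illustrative and not taken from the training scripts:

    # Minimal sketch of the timestep-sampling pattern in the training scripts.
    # The scheduler and the placeholder latents are illustrative, not from the diff.
    import torch
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    latents = torch.randn(4, 4, 64, 64)  # stand-in for VAE-encoded latents (bsz=4)

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # The number of training timesteps is read from the scheduler config:
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()
    # Add noise to the latents according to the noise magnitude at each timestep
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)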