Unverified commit e550163b, authored by Patrick von Platen, committed by GitHub

[Vae] Make sure all vae's work with latent diffusion models (#5880)

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more
parent 20f0cbc8
```diff
@@ -108,6 +108,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        self.register_to_config(block_out_channels=up_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     @apply_forward_hook
     def encode(
         self, x: torch.FloatTensor, return_dict: bool = True
```
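Registering `force_upcast=False` on these VAE classes matters because pipelines read `vae.config.force_upcast` to decide whether the VAE must run in float32; without the key in the config that attribute lookup fails, and with it set to `False` the fp32 round-trip is skipped for VAEs that are numerically stable in fp16. A sketch of the consuming pattern, reusing `retrieve_latents` from the sketch above; `encode_image` is a hypothetical helper, and the real pipeline code appears in the SDXL hunks further down:

```python
import torch


def encode_image(vae, image, generator=None):
    # Upcast only when the VAE is fp16 AND declares force_upcast=True.
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        vae.to(torch.float32)  # encode in fp32 for numerical stability
        image = image.float()
    latents = retrieve_latents(vae.encode(image), generator=generator)
    if needs_upcasting:
        vae.to(torch.float16)  # cast back to save memory
    return latents
```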
```diff
@@ -148,6 +148,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.tile_sample_min_size = 512
         self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
 
+        self.register_to_config(block_out_channels=decoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
     def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
```
```diff
@@ -138,6 +138,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         )
         self.decoder_scheduler = ConsistencyDecoderScheduler()
         self.register_to_config(block_out_channels=encoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
         self.register_buffer(
             "means",
             torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
```
```diff
@@ -76,9 +76,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
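The final `latents` branch is what lets `AutoencoderTiny` flow through these call sites unchanged: its `encode` returns an output carrying a plain `latents` tensor and no `latent_dist`, so `generator` and `sample_mode` are simply ignored for it. A quick sketch, reusing the imports and `retrieve_latents` definition from the first example:

```python
from diffusers import AutoencoderTiny

tiny_vae = AutoencoderTiny()  # default config, random weights
out = tiny_vae.encode(torch.randn(1, 3, 64, 64))

assert hasattr(out, "latents") and not hasattr(out, "latent_dist")
latents = retrieve_latents(out)  # hits the `latents` branch; nothing is sampled
```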
```diff
@@ -92,9 +92,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -104,9 +104,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -54,6 +54,20 @@ if is_invisible_watermark_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
```
```diff
@@ -824,12 +838,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)
```
```diff
@@ -133,9 +133,13 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -44,9 +44,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -35,9 +35,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -61,6 +61,20 @@ def preprocess(image):
     return image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
     # 1. get previous step value (=t-1)
     prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
```
```diff
@@ -567,11 +581,12 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         if isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
```
```diff
@@ -37,9 +37,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -73,9 +73,13 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -160,9 +160,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -58,6 +58,20 @@ def preprocess(image):
     return image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
```
```diff
@@ -320,7 +334,6 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             prompt_embeds.dtype,
             device,
             self.do_classifier_free_guidance,
-            generator,
         )
 
         height, width = image_latents.shape[-2:]
```
```diff
@@ -716,17 +729,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         if image.shape[1] == 4:
             image_latents = image
         else:
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            if isinstance(generator, list):
-                image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-                image_latents = torch.cat(image_latents, dim=0)
-            else:
-                image_latents = self.vae.encode(image).latent_dist.mode()
+            image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
 
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand image_latents for batch_size
```
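InstructPix2Pix conditions on a deterministic image latent, so the new single-line branch requests the posterior mode via `sample_mode="argmax"` instead of drawing a sample; for the diagonal Gaussian posterior the mode equals the mean, which is exactly what the removed per-batch `.latent_dist.mode()` calls computed, and no generator bookkeeping is needed. A short sketch of the equivalence, reusing the toy `vae` and `image` from the first example:

```python
dist = vae.encode(image).latent_dist

# For a diagonal Gaussian, mode == mean, and no generator is involved.
assert torch.equal(dist.mode(), dist.mean)
assert torch.equal(retrieve_latents(vae.encode(image), sample_mode="argmax"), dist.mode())
```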
```diff
@@ -105,9 +105,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -250,9 +250,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```
```diff
@@ -88,6 +88,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
```
```diff
@@ -533,17 +547,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
             self.upcast_vae()
             image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if isinstance(generator, list):
-            image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = self.vae.encode(image).latent_dist.mode()
+        image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
 
         # cast back to fp16 if needed
         if needs_upcasting:
```
```diff
@@ -866,7 +870,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
             prompt_embeds.dtype,
             device,
             do_classifier_free_guidance,
-            generator,
         )
 
         # 7. Prepare latent variables
```
```diff
@@ -79,6 +79,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
     # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
     # reshape to ncfhw
```
```diff
@@ -547,14 +561,14 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(video).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(video), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
```
```diff
@@ -46,6 +46,82 @@ from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
 enable_full_determinism()
 
 
+def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "block_out_channels": block_out_channels,
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+    }
+    return init_dict
+
+
+def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "down_block_out_channels": block_out_channels,
+        "layers_per_down_block": 1,
+        "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+        "up_block_out_channels": block_out_channels,
+        "layers_per_up_block": 1,
+        "act_fn": "silu",
+        "latent_channels": 4,
+        "norm_num_groups": norm_num_groups,
+        "sample_size": 32,
+        "scaling_factor": 0.18215,
+    }
+    return init_dict
+
+
+def get_autoencoder_tiny_config(block_out_channels=None):
+    block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
+    init_dict = {
+        "in_channels": 3,
+        "out_channels": 3,
+        "encoder_block_out_channels": block_out_channels,
+        "decoder_block_out_channels": block_out_channels,
+        "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
+        "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
+    }
+    return init_dict
+
+
+def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
+    block_out_channels = block_out_channels or [32, 64]
+    norm_num_groups = norm_num_groups or 32
+    return {
+        "encoder_block_out_channels": block_out_channels,
+        "encoder_in_channels": 3,
+        "encoder_out_channels": 4,
+        "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+        "decoder_add_attention": False,
+        "decoder_block_out_channels": block_out_channels,
+        "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
+        "decoder_downsample_padding": 1,
+        "decoder_in_channels": 7,
+        "decoder_layers_per_block": 1,
+        "decoder_norm_eps": 1e-05,
+        "decoder_norm_num_groups": norm_num_groups,
+        "encoder_norm_num_groups": norm_num_groups,
+        "decoder_num_train_timesteps": 1024,
+        "decoder_out_channels": 6,
+        "decoder_resnet_time_scale_shift": "scale_shift",
+        "decoder_time_embedding_type": "learned",
+        "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
+        "scaling_factor": 1,
+        "latent_channels": 4,
+    }
+
+
 class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
```
```diff
@@ -70,14 +146,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "block_out_channels": [32, 64],
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "latent_channels": 4,
-        }
+        init_dict = get_autoencoder_kl_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
```
```diff
@@ -214,21 +283,7 @@ class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.T
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "down_block_out_channels": [32, 64],
-            "layers_per_down_block": 1,
-            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            "up_block_out_channels": [32, 64],
-            "layers_per_up_block": 1,
-            "act_fn": "silu",
-            "latent_channels": 4,
-            "norm_num_groups": 32,
-            "sample_size": 32,
-            "scaling_factor": 0.18215,
-        }
+        init_dict = get_asym_autoencoder_kl_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
```
```diff
@@ -263,14 +318,7 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
         return (3, 32, 32)
 
     def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "in_channels": 3,
-            "out_channels": 3,
-            "encoder_block_out_channels": (32, 32),
-            "decoder_block_out_channels": (32, 32),
-            "num_encoder_blocks": (1, 2),
-            "num_decoder_blocks": (2, 1),
-        }
+        init_dict = get_autoencoder_tiny_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
```
```diff
@@ -302,33 +350,7 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     @property
     def init_dict(self):
-        return {
-            "encoder_block_out_channels": [32, 64],
-            "encoder_in_channels": 3,
-            "encoder_out_channels": 4,
-            "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            "decoder_add_attention": False,
-            "decoder_block_out_channels": [32, 64],
-            "decoder_down_block_types": [
-                "ResnetDownsampleBlock2D",
-                "ResnetDownsampleBlock2D",
-            ],
-            "decoder_downsample_padding": 1,
-            "decoder_in_channels": 7,
-            "decoder_layers_per_block": 1,
-            "decoder_norm_eps": 1e-05,
-            "decoder_norm_num_groups": 32,
-            "decoder_num_train_timesteps": 1024,
-            "decoder_out_channels": 6,
-            "decoder_resnet_time_scale_shift": "scale_shift",
-            "decoder_time_embedding_type": "learned",
-            "decoder_up_block_types": [
-                "ResnetUpsampleBlock2D",
-                "ResnetUpsampleBlock2D",
-            ],
-            "scaling_factor": 1,
-            "latent_channels": 4,
-        }
+        return get_consistency_vae_config()
 
     def prepare_init_args_and_inputs_for_common(self):
         return self.init_dict, self.inputs_dict()
```
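The shared config getters also make it easy to build matching toy models outside the test classes; a small sketch (all four constructors accept the corresponding getter's kwargs, as the tests above exercise):

```python
from diffusers import (
    AsymmetricAutoencoderKL,
    AutoencoderKL,
    AutoencoderTiny,
    ConsistencyDecoderVAE,
)

# Each getter returns init kwargs sized for fast, deterministic unit tests.
vae = AutoencoderKL(**get_autoencoder_kl_config())
asym_vae = AsymmetricAutoencoderKL(**get_asym_autoencoder_kl_config())
tiny_vae = AutoencoderTiny(**get_autoencoder_tiny_config())
consistency_vae = ConsistencyDecoderVAE(**get_consistency_vae_config())
```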