Unverified Commit 56b3b216 authored by Sayak Paul, committed by GitHub

[Refactor autoencoders] feat: introduce autoencoders module (#6129)

* feat: introduce autoencoders module

* more changes for styling and copy fixing

* path changes in the docs.

* fix: import structure in init.

* fix controlnetxs import
parent 9cef07da
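
A quick orientation before the hunks — assuming a diffusers build that includes this refactor, the public top-level API is untouched; only deep module paths gain the `autoencoders` segment:

```python
# Top-level imports are unaffected by the refactor.
from diffusers import AutoencoderKL, ConsistencyDecoderVAE

# Deep imports must follow the new subpackage (compare the converted
# script's diff below); this path replaces the old diffusers.models.vae.
from diffusers.models.autoencoders.vae import Encoder
```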
@@ -49,12 +49,12 @@ make_image_grid([original_image, mask_image, image], rows=1, cols=3)
 
 ## AsymmetricAutoencoderKL
 
-[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
+[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL
 
 ## AutoencoderKLOutput
 
-[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
 
 ## DecoderOutput
 
-[[autodoc]] models.vae.DecoderOutput
+[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -54,4 +54,4 @@ image
 
 ## AutoencoderTinyOutput
 
-[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
+[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput
@@ -36,11 +36,11 @@ model = AutoencoderKL.from_single_file(url)
 
 ## AutoencoderKLOutput
 
-[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
 
 ## DecoderOutput
 
-[[autodoc]] models.vae.DecoderOutput
+[[autodoc]] models.autoencoders.vae.DecoderOutput
 
 ## FlaxAutoencoderKL
...
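
The `[[autodoc]]` directives in the doc hunks above take dotted paths that appear to resolve relative to the `diffusers` package, so they must track the move exactly as imports do; for illustration, the last path above corresponds to:

```python
# Hypothetical equivalence: [[autodoc]] models.autoencoders.vae.DecoderOutput
# documents the object importable as:
from diffusers.models.autoencoders.vae import DecoderOutput
```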
@@ -12,9 +12,9 @@ from safetensors.torch import load_file as stl
 from tqdm import tqdm
 
 from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
+from diffusers.models.autoencoders.vae import Encoder
 from diffusers.models.embeddings import TimestepEmbedding
 from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
-from diffusers.models.vae import Encoder
 
 args = ArgumentParser()
...
@@ -26,11 +26,11 @@ _import_structure = {}
 
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
-    _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
-    _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
-    _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
-    _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
-    _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
+    _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
+    _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["controlnetxs"] = ["ControlNetXSModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
@@ -58,11 +58,13 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
-        from .autoencoder_asym_kl import AsymmetricAutoencoderKL
-        from .autoencoder_kl import AutoencoderKL
-        from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
-        from .autoencoder_tiny import AutoencoderTiny
-        from .consistency_decoder_vae import ConsistencyDecoderVAE
+        from .autoencoders import (
+            AsymmetricAutoencoderKL,
+            AutoencoderKL,
+            AutoencoderKLTemporalDecoder,
+            AutoencoderTiny,
+            ConsistencyDecoderVAE,
+        )
         from .controlnet import ControlNetModel
         from .controlnetxs import ControlNetXSModel
         from .dual_transformer_2d import DualTransformer2DModel
...
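
The `_import_structure` entries above are keys into diffusers' lazy-import machinery: nothing under `autoencoders` is actually imported until one of the listed names is first accessed. A minimal sketch of the pattern — diffusers' real helper is `_LazyModule`, so this PEP 562 `__getattr__` version is only illustrative and would live in a package `__init__.py`:

```python
# Illustrative lazy-import pattern driven by an _import_structure mapping.
import importlib
from typing import Dict, List

_import_structure: Dict[str, List[str]] = {
    "autoencoders.autoencoder_kl": ["AutoencoderKL"],
    "autoencoders.autoencoder_tiny": ["AutoencoderTiny"],
}

# Invert the mapping: public name -> defining submodule.
_name_to_module = {
    name: module for module, names in _import_structure.items() for name in names
}

def __getattr__(name: str):
    # Triggered only on first attribute access, so submodules load on demand.
    if name in _name_to_module:
        module = importlib.import_module(f".{_name_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```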
+from .autoencoder_asym_kl import AsymmetricAutoencoderKL
+from .autoencoder_kl import AutoencoderKL
+from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+from .autoencoder_tiny import AutoencoderTiny
+from .consistency_decoder_vae import ConsistencyDecoderVAE
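
The five added lines above are evidently the body of the new `autoencoders/__init__.py`. Because the subpackage re-exports each class, the grouped `from .autoencoders import (...)` statement in `models/__init__.py` and a direct submodule import resolve to the same object — an illustrative check:

```python
# Both routes should yield the same class object.
from diffusers.models.autoencoders import AutoencoderKL as via_package
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL as via_module

assert via_package is via_module
```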
@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils.accelerate_utils import apply_forward_hook
-from .modeling_outputs import AutoencoderKLOutput
-from .modeling_utils import ModelMixin
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
...
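
This hunk and the next several follow one mechanical rule: each file moved one package level deeper, so every relative import that climbs out of the old `models/` directory gains a dot, while imports of siblings that moved along (like `.vae`) keep their form. Sketched with assumed standard-layout paths:

```python
# before: src/diffusers/models/autoencoder_asym_kl.py
#   from ..configuration_utils import ...   # -> diffusers/configuration_utils.py
#   from .modeling_utils import ModelMixin  # -> diffusers/models/modeling_utils.py
#   from .vae import ...                    # -> diffusers/models/vae.py
#
# after:  src/diffusers/models/autoencoders/autoencoder_asym_kl.py
#   from ...configuration_utils import ...  # one extra dot to leave autoencoders/
#   from ..modeling_utils import ModelMixin # modeling_utils stayed in models/
#   from .vae import ...                    # vae.py moved along, still a sibling
```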
@@ -16,10 +16,10 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalVAEMixin
-from ..utils.accelerate_utils import apply_forward_hook
-from .attention_processor import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalVAEMixin
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
     Attention,
@@ -27,8 +27,8 @@ from .attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from .modeling_outputs import AutoencoderKLOutput
-from .modeling_utils import ModelMixin
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
...
@@ -16,14 +16,14 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalVAEMixin
-from ..utils import is_torch_version
-from ..utils.accelerate_utils import apply_forward_hook
-from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
-from .modeling_outputs import AutoencoderKLOutput
-from .modeling_utils import ModelMixin
-from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalVAEMixin
+from ...utils import is_torch_version
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
 from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
...
@@ -18,10 +18,10 @@ from typing import Optional, Tuple, Union
 import torch
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
-from ..utils.accelerate_utils import apply_forward_hook
-from .modeling_utils import ModelMixin
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DecoderTiny, EncoderTiny
...
@@ -18,20 +18,20 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..schedulers import ConsistencyDecoderScheduler
-from ..utils import BaseOutput
-from ..utils.accelerate_utils import apply_forward_hook
-from ..utils.torch_utils import randn_tensor
-from .attention_processor import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...schedulers import ConsistencyDecoderScheduler
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ...utils.torch_utils import randn_tensor
+from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from .modeling_utils import ModelMixin
-from .unet_2d import UNet2DModel
+from ..modeling_utils import ModelMixin
+from ..unet_2d import UNet2DModel
 from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -153,7 +153,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
+    # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
@@ -162,7 +162,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         """
         self.use_tiling = use_tiling
 
-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
+    # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
     def disable_tiling(self):
         r"""
         Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
@@ -170,7 +170,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         """
         self.enable_tiling(False)
 
-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
+    # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
    def enable_slicing(self):
         r"""
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         """
         self.use_slicing = True
 
-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
+    # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
     def disable_slicing(self):
         r"""
         Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         return DecoderOutput(sample=x_0)
 
-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
+    # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
     def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[2], b.shape[2], blend_extent)
         for y in range(blend_extent):
             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
         return b
 
-    # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
+    # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
     def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[3], b.shape[3], blend_extent)
         for x in range(blend_extent):
...
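
Two notes on the hunks above. First, the `# Copied from` paths are not cosmetic: diffusers' `make fix-copies` tooling parses them to keep duplicated method bodies in sync, so they must track the new module location. Second, the `blend_v` body visible in the diff linearly cross-fades the seam rows of two vertically adjacent tiles, weighting the lower tile's row `y` by `y / blend_extent` and the upper tile's overlapping row by `1 - y / blend_extent`. A standalone check of that logic (function lifted from the diff, tensors hypothetical):

```python
import torch

def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same logic as the method above, detached from the class for testing.
    blend_extent = min(a.shape[2], b.shape[2], blend_extent)
    for y in range(blend_extent):
        b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
    return b

a = torch.ones(1, 1, 4, 2)   # upper tile: all 1.0
b = torch.zeros(1, 1, 4, 2)  # lower tile: all 0.0
out = blend_v(a, b, blend_extent=4)
print(out[0, 0, :, 0])       # tensor([1.0000, 0.7500, 0.5000, 0.2500])
```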
@@ -18,11 +18,11 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from ..utils import BaseOutput, is_torch_version
-from ..utils.torch_utils import randn_tensor
-from .activations import get_activation
-from .attention_processor import SpatialNorm
-from .unet_2d_blocks import (
+from ...utils import BaseOutput, is_torch_version
+from ...utils.torch_utils import randn_tensor
+from ..activations import get_activation
+from ..attention_processor import SpatialNorm
+from ..unet_2d_blocks import (
     AutoencoderTinyBlock,
     UNetMidBlock2D,
     get_down_block,
...
@@ -26,7 +26,7 @@ from ..utils import BaseOutput, logging
 from .attention_processor import (
     AttentionProcessor,
 )
-from .autoencoder_kl import AutoencoderKL
+from .autoencoders import AutoencoderKL
 from .lora import LoRACompatibleConv
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
...
@@ -20,8 +20,8 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
 from ..utils.accelerate_utils import apply_forward_hook
+from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
 from .modeling_utils import ModelMixin
-from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
 
 
 @dataclass
...
@@ -19,8 +19,8 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
 from ...models.modeling_utils import ModelMixin
-from ...models.vae import DecoderOutput, VectorQuantizer
 from ...models.vq_model import VQEncoderOutput
 from ...utils.accelerate_utils import apply_forward_hook
...