Unverified Commit a5a0ccf8 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[core] `AutoencoderMixin` to abstract common methods (#12473)

* up

* correct wording.

* up

* up

* up
parent dd07b19e
...@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images. KL loss for encoding images into latents and decoding latent representations into images.
...@@ -107,9 +107,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -107,9 +107,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels) self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False) self.register_to_config(force_upcast=False)
......
...@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention ...@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv from ..transformers.sana_transformer import GLUMBConv
from .vae import DecoderOutput, EncoderOutput from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
class ResBlock(nn.Module): class ResBlock(nn.Module):
...@@ -378,7 +378,7 @@ class Decoder(nn.Module): ...@@ -378,7 +378,7 @@ class Decoder(nn.Module):
return hidden_states return hidden_states
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r""" r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629). [SANA](https://huggingface.co/papers/2410.10629).
...@@ -536,27 +536,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -536,27 +536,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor: def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape batch_size, num_channels, height, width = x.shape
......
...@@ -32,10 +32,10 @@ from ..attention_processor import ( ...@@ -32,10 +32,10 @@ from ..attention_processor import (
) )
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
...@@ -138,35 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter ...@@ -138,35 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25 self.tile_overlap_factor = 0.25
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property @property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
......
...@@ -28,6 +28,7 @@ from ..modeling_outputs import AutoencoderKLOutput ...@@ -28,6 +28,7 @@ from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D from ..upsampling import Upsample2D
from .vae import AutoencoderMixin
class AllegroTemporalConvLayer(nn.Module): class AllegroTemporalConvLayer(nn.Module):
...@@ -673,7 +674,7 @@ class AllegroDecoder3D(nn.Module): ...@@ -673,7 +674,7 @@ class AllegroDecoder3D(nn.Module):
return sample return sample
class AutoencoderKLAllegro(ModelMixin, ConfigMixin): class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro). [Allegro](https://github.com/rhymes-ai/Allegro).
...@@ -795,35 +796,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin): ...@@ -795,35 +796,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_size - self.tile_overlap_w, sample_size - self.tile_overlap_w,
) )
def enable_tiling(self) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor: def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan) # TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
......
...@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D ...@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module): ...@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module):
return hidden_states, new_conv_cache return hidden_states, new_conv_cache
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo). [CogVideoX](https://github.com/THUDM/CogVideo).
...@@ -1124,27 +1124,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -1124,27 +1124,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor: def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape batch_size, num_channels, num_frames, height, width = x.shape
......
...@@ -24,7 +24,7 @@ from ...utils import get_logger ...@@ -24,7 +24,7 @@ from ...utils import get_logger
from ...utils.accelerate_utils import apply_forward_hook from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, IdentityDistribution from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module): ...@@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module):
return hidden_states return hidden_states
class AutoencoderKLCosmos(ModelMixin, ConfigMixin): class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575). Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
...@@ -1031,27 +1031,6 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin): ...@@ -1031,27 +1031,6 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor: def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x) x = self.encoder(x)
enc = self.quant_conv(x) enc = self.quant_conv(x)
......
...@@ -26,7 +26,7 @@ from ..activations import get_activation ...@@ -26,7 +26,7 @@ from ..activations import get_activation
from ..attention_processor import Attention from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -624,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module): ...@@ -624,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module):
return hidden_states return hidden_states
class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
...@@ -763,27 +763,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): ...@@ -763,27 +763,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor: def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape batch_size, num_channels, num_frames, height, width = x.shape
......
...@@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings ...@@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class LTXVideoCausalConv3d(nn.Module): class LTXVideoCausalConv3d(nn.Module):
...@@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module): ...@@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module):
return hidden_states return hidden_states
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video). [LTX](https://huggingface.co/Lightricks/LTX-Video).
...@@ -1219,27 +1219,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -1219,27 +1219,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor: def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape batch_size, num_channels, num_frames, height, width = x.shape
......
...@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook ...@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module): ...@@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module):
return hidden_states return hidden_states
class AutoencoderKLMagvit(ModelMixin, ConfigMixin): class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991). model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
...@@ -805,27 +805,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin): ...@@ -805,27 +805,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook @apply_forward_hook
def _encode( def _encode(
self, x: torch.Tensor, return_dict: bool = True self, x: torch.Tensor, return_dict: bool = True
......
...@@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 ...@@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module): ...@@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module):
return hidden_states, new_conv_cache return hidden_states, new_conv_cache
class AutoencoderKLMochi(ModelMixin, ConfigMixin): class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models). [Mochi 1 preview](https://github.com/genmoai/models).
...@@ -818,27 +818,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin): ...@@ -818,27 +818,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _enable_framewise_encoding(self): def _enable_framewise_encoding(self):
r""" r"""
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
......
...@@ -31,7 +31,7 @@ from ...utils.accelerate_utils import apply_forward_hook ...@@ -31,7 +31,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -663,7 +663,7 @@ class QwenImageDecoder3d(nn.Module): ...@@ -663,7 +663,7 @@ class QwenImageDecoder3d(nn.Module):
return x return x
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r""" r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
...@@ -763,27 +763,6 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -763,27 +763,6 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self): def clear_cache(self):
def _count_conv3d(model): def _count_conv3d(model):
count = 0 count = 0
......
...@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor ...@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module): class TemporalDecoder(nn.Module):
...@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module): ...@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module):
return sample return sample
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
......
...@@ -25,7 +25,7 @@ from ...utils.accelerate_utils import apply_forward_hook ...@@ -25,7 +25,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -951,7 +951,7 @@ def unpatchify(x, patch_size): ...@@ -951,7 +951,7 @@ def unpatchify(x, patch_size):
return x return x
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r""" r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1]. Introduced in [Wan 2.1].
...@@ -1110,27 +1110,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -1110,27 +1110,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self): def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"] self._conv_num = self._cached_conv_counts["decoder"]
......
...@@ -25,6 +25,7 @@ from ...utils import BaseOutput ...@@ -25,6 +25,7 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin
class Snake1d(nn.Module): class Snake1d(nn.Module):
...@@ -291,7 +292,7 @@ class OobleckDecoder(nn.Module): ...@@ -291,7 +292,7 @@ class OobleckDecoder(nn.Module):
return hidden_state return hidden_state
class AutoencoderOobleck(ModelMixin, ConfigMixin): class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio. introduced in Stable Audio.
...@@ -356,20 +357,6 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin): ...@@ -356,20 +357,6 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
self.use_slicing = False self.use_slicing = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook @apply_forward_hook
def encode( def encode(
self, x: torch.Tensor, return_dict: bool = True self, x: torch.Tensor, return_dict: bool = True
......
...@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny
@dataclass @dataclass
...@@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput): ...@@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput):
latents: torch.Tensor latents: torch.Tensor
class AutoencoderTiny(ModelMixin, ConfigMixin): class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images. A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
...@@ -162,35 +162,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): ...@@ -162,35 +162,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
"""[0, 1] -> raw latents""" """[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder. r"""Encode a batch of images using a tiled encoder.
......
...@@ -32,7 +32,7 @@ from ..attention_processor import ( ...@@ -32,7 +32,7 @@ from ..attention_processor import (
) )
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..unets.unet_2d import UNet2DModel from ..unets.unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass @dataclass
...@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput): ...@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution" latent_dist: "DiagonalGaussianDistribution"
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
The consistency decoder used with DALL-E 3. The consistency decoder used with DALL-E 3.
...@@ -167,39 +167,6 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ...@@ -167,39 +167,6 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25 self.tile_overlap_factor = 0.25
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property @property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
......
...@@ -894,3 +894,38 @@ class DecoderTiny(nn.Module): ...@@ -894,3 +894,38 @@ class DecoderTiny(nn.Module):
# scale image from [0, 1] to [-1, 1] to match diffusers convention # scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1) return x.mul(2).sub(1)
class AutoencoderMixin:
def enable_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
if not hasattr(self, "use_tiling"):
raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.")
self.use_tiling = True
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
if not hasattr(self, "use_slicing"):
raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.")
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
...@@ -22,6 +22,7 @@ from ...utils import BaseOutput ...@@ -22,6 +22,7 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin
@dataclass @dataclass
...@@ -37,7 +38,7 @@ class VQEncoderOutput(BaseOutput): ...@@ -37,7 +38,7 @@ class VQEncoderOutput(BaseOutput):
latents: torch.Tensor latents: torch.Tensor
class VQModel(ModelMixin, ConfigMixin): class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
r""" r"""
A VQ-VAE model for decoding latent representations. A VQ-VAE model for decoding latent representations.
......
...@@ -57,6 +57,9 @@ class AutoencoderTesterMixin: ...@@ -57,6 +57,9 @@ class AutoencoderTesterMixin:
torch.manual_seed(0) torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
if not hasattr(model, "use_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
inputs_dict.update({"return_dict": False}) inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None) _ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model) accepts_generator = self._accepts_generator(model)
...@@ -102,6 +105,8 @@ class AutoencoderTesterMixin: ...@@ -102,6 +105,8 @@ class AutoencoderTesterMixin:
torch.manual_seed(0) torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
if not hasattr(model, "use_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
inputs_dict.update({"return_dict": False}) inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None) _ = inputs_dict.pop("generator", None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment