Unverified Commit 038b42db authored by Aryan V S, committed by GitHub

Improve docs and type hints (#5759)



* improvement: docs and type hints

* improvement: docs and type hints

minor refactor

* improvement: docs and type hints

* update with suggestions from review
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent ecbe27a0
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image
@@ -126,14 +126,14 @@ class VaeImageProcessor(ConfigMixin):
         return images

     @staticmethod
-    def normalize(images):
+    def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Normalize an image array to [-1,1].
         """
         return 2.0 * images - 1.0

     @staticmethod
-    def denormalize(images):
+    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Denormalize an image array to [0,1].
         """
@@ -159,10 +159,10 @@ class VaeImageProcessor(ConfigMixin):
     def get_default_height_width(
         self,
-        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ):
+    ) -> Tuple[int, int]:
         """
         This function return the height and width that are downscaled to the next integer multiple of
         `vae_scale_factor`.
@@ -202,12 +202,24 @@ class VaeImageProcessor(ConfigMixin):
     def resize(
         self,
-        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Resize image.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, can be a PIL image, numpy array or pytorch tensor.
+            height (`int`, *optional*, defaults to `None`):
+                The height to resize to.
+            width (`int`, *optional*, defaults to `None`):
+                The width to resize to.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+                The resized image.
         """
         if isinstance(image, PIL.Image.Image):
             image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
@@ -227,7 +239,15 @@ class VaeImageProcessor(ConfigMixin):
     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
         """
-        create a mask
+        Create a mask.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The image input, should be a PIL image.
+
+        Returns:
+            `PIL.Image.Image`:
+                The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
         """
         image[image < 0.5] = 0
         image[image >= 0.5] = 1
@@ -327,7 +347,23 @@ class VaeImageProcessor(ConfigMixin):
         image: torch.FloatTensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ):
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.FloatTensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+                The postprocessed image.
+        """
         if not isinstance(image, torch.Tensor):
             raise ValueError(
                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -390,7 +426,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         super().__init__()

     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a NumPy image or a batch of images to a PIL image.
         """
@@ -406,7 +442,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         return pil_images

     @staticmethod
-    def rgblike_to_depthmap(image):
+    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Args:
             image: RGB-like depth image
@@ -416,7 +452,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]

-    def numpy_to_depth(self, images):
+    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a NumPy depth image or a batch of images to a PIL image.
         """
@@ -441,7 +477,23 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         image: torch.FloatTensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ):
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.FloatTensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+                The postprocessed image.
+        """
         if not isinstance(image, torch.Tensor):
             raise ValueError(
                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
...
...@@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self, self,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",), down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
down_block_out_channels: Tuple[int] = (64,), down_block_out_channels: Tuple[int, ...] = (64,),
layers_per_down_block: int = 1, layers_per_down_block: int = 1,
up_block_types: Tuple[str] = ("UpDecoderBlock2D",), up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
up_block_out_channels: Tuple[int] = (64,), up_block_out_channels: Tuple[int, ...] = (64,),
layers_per_up_block: int = 1, layers_per_up_block: int = 1,
act_fn: str = "silu", act_fn: str = "silu",
latent_channels: int = 4, latent_channels: int = 4,
...@@ -109,7 +109,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -109,7 +109,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.use_tiling = False self.use_tiling = False
@apply_forward_hook @apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments) posterior = DiagonalGaussianDistribution(moments)
...@@ -125,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -125,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
image: Optional[torch.FloatTensor] = None, image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]: ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
z = self.post_quant_conv(z) z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask) dec = self.decoder(z, image, mask)
...@@ -142,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -142,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
image: Optional[torch.FloatTensor] = None, image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]: ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
decoded = self._decode(z, image, mask).sample decoded = self._decode(z, image, mask).sample
if not return_dict: if not return_dict:
...@@ -157,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -157,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
sample_posterior: bool = False, sample_posterior: bool = False,
return_dict: bool = True, return_dict: bool = True,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]: ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
r""" r"""
Args: Args:
sample (`torch.FloatTensor`): Input sample. sample (`torch.FloatTensor`): Input sample.
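
A note on the recurring `Tuple[str] -> Tuple[str, ...]` change throughout this commit: in Python typing, `Tuple[X]` describes a tuple of exactly one element, so the old hints did not match defaults like `(64, 64, 64, 64)`; `Tuple[X, ...]` accepts any length. A quick sketch of the distinction:

from typing import Tuple

one_block: Tuple[str] = ("DownEncoderBlock2D",)  # exactly one element
many_channels: Tuple[int, ...] = (64, 64, 64, 64)  # any length, homogeneous type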
...
...@@ -322,13 +322,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -322,13 +322,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=decoded) return DecoderOutput(sample=decoded)
def blend_v(self, a, b, blend_extent): def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent) blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent): for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b return b
def blend_h(self, a, b, blend_extent): def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent) blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent): for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
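
The newly annotated `blend_v`/`blend_h` implement a linear cross-fade over the overlap between adjacent tiles. A standalone sketch of the arithmetic for `blend_extent=4` (values are illustrative):

import torch

a = torch.zeros(1, 1, 1, 8)  # left tile, constant 0
b = torch.ones(1, 1, 1, 8)   # right tile, constant 1
blend_extent = 4
for x in range(blend_extent):
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
print(b[0, 0, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])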
...
@@ -96,18 +96,18 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
-        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
+        in_channels: int = 3,
+        out_channels: int = 3,
+        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
+        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         act_fn: str = "relu",
         latent_channels: int = 4,
         upsampling_scaling_factor: int = 2,
-        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
-        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
+        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
+        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
         latent_magnitude: int = 3,
         latent_shift: float = 0.5,
-        force_upcast: float = False,
+        force_upcast: bool = False,
         scaling_factor: float = 1.0,
     ):
         super().__init__()
@@ -147,33 +147,33 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self.tile_sample_min_size = 512
         self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value

-    def scale_latents(self, x):
+    def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

-    def unscale_latents(self, x):
+    def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

-    def enable_slicing(self):
+    def enable_slicing(self) -> None:
         r"""
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
         """
         self.use_slicing = True

-    def disable_slicing(self):
+    def disable_slicing(self) -> None:
         r"""
         Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
         decoding in one step.
         """
         self.use_slicing = False

-    def enable_tiling(self, use_tiling: bool = True):
+    def enable_tiling(self, use_tiling: bool = True) -> None:
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
@@ -181,7 +181,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """
         self.use_tiling = use_tiling

-    def disable_tiling(self):
+    def disable_tiling(self) -> None:
         r"""
         Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
         decoding in one step.
@@ -197,13 +197,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         Args:
             x (`torch.FloatTensor`): Input batch of images.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.

         Returns:
-            [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
-                If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
-                plain `tuple` is returned.
+            `torch.FloatTensor`: Encoded batch of images.
         """
         # scale of encoder output relative to input
         sf = self.spatial_scale_factor
@@ -249,13 +245,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         Args:
             x (`torch.FloatTensor`): Input batch of images.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.

         Returns:
-            [`~models.vae.DecoderOutput`] or `tuple`:
-                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
-                returned.
+            `torch.FloatTensor`: Encoded batch of images.
         """
         # scale of decoder output relative to input
         sf = self.spatial_scale_factor
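
With the defaults annotated above (`latent_magnitude=3`, `latent_shift=0.5`), `scale_latents` maps raw latents through x / 6 + 0.5 into [0, 1] and `unscale_latents` inverts it. A worked check of the same arithmetic:

import torch

latent_magnitude, latent_shift = 3, 0.5
raw = torch.tensor([-3.0, 0.0, 3.0])
scaled = raw.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)  # tensor([0.0, 0.5, 1.0])
back = scaled.sub(latent_shift).mul(2 * latent_magnitude)             # tensor([-3.0, 0.0, 3.0])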
...
@@ -70,39 +70,39 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        scaling_factor=0.18215,
-        latent_channels=4,
-        encoder_act_fn="silu",
-        encoder_block_out_channels=(128, 256, 512, 512),
-        encoder_double_z=True,
-        encoder_down_block_types=(
+        scaling_factor: float = 0.18215,
+        latent_channels: int = 4,
+        encoder_act_fn: str = "silu",
+        encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+        encoder_double_z: bool = True,
+        encoder_down_block_types: Tuple[str, ...] = (
             "DownEncoderBlock2D",
             "DownEncoderBlock2D",
             "DownEncoderBlock2D",
             "DownEncoderBlock2D",
         ),
-        encoder_in_channels=3,
-        encoder_layers_per_block=2,
-        encoder_norm_num_groups=32,
-        encoder_out_channels=4,
-        decoder_add_attention=False,
-        decoder_block_out_channels=(320, 640, 1024, 1024),
-        decoder_down_block_types=(
+        encoder_in_channels: int = 3,
+        encoder_layers_per_block: int = 2,
+        encoder_norm_num_groups: int = 32,
+        encoder_out_channels: int = 4,
+        decoder_add_attention: bool = False,
+        decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
+        decoder_down_block_types: Tuple[str, ...] = (
             "ResnetDownsampleBlock2D",
             "ResnetDownsampleBlock2D",
             "ResnetDownsampleBlock2D",
             "ResnetDownsampleBlock2D",
         ),
-        decoder_downsample_padding=1,
-        decoder_in_channels=7,
-        decoder_layers_per_block=3,
-        decoder_norm_eps=1e-05,
-        decoder_norm_num_groups=32,
-        decoder_num_train_timesteps=1024,
-        decoder_out_channels=6,
-        decoder_resnet_time_scale_shift="scale_shift",
-        decoder_time_embedding_type="learned",
-        decoder_up_block_types=(
+        decoder_downsample_padding: int = 1,
+        decoder_in_channels: int = 7,
+        decoder_layers_per_block: int = 3,
+        decoder_norm_eps: float = 1e-05,
+        decoder_norm_num_groups: int = 32,
+        decoder_num_train_timesteps: int = 1024,
+        decoder_out_channels: int = 6,
+        decoder_resnet_time_scale_shift: str = "scale_shift",
+        decoder_time_embedding_type: str = "learned",
+        decoder_up_block_types: Tuple[str, ...] = (
             "ResnetUpsampleBlock2D",
             "ResnetUpsampleBlock2D",
             "ResnetUpsampleBlock2D",
@@ -304,8 +304,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         z: torch.FloatTensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-        num_inference_steps=2,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        num_inference_steps: int = 2,
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         z = (z * self.config.scaling_factor - self.means) / self.stds

         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         return DecoderOutput(sample=x_0)

     # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
-    def blend_v(self, a, b, blend_extent):
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[2], b.shape[2], blend_extent)
         for y in range(blend_extent):
             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
         return b

     # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
-    def blend_h(self, a, b, blend_extent):
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         blend_extent = min(a.shape[3], b.shape[3], blend_extent)
         for x in range(blend_extent):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
@@ -407,7 +407,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
         r"""
         Args:
             sample (`torch.FloatTensor`): Input sample.
@@ -415,6 +415,12 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+            generator (`torch.Generator`, *optional*, defaults to `None`):
+                Generator to use for sampling.
+
+        Returns:
+            [`DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
         """
         x = sample
         posterior = self.encode(x).latent_dist
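
The widened `Union[DecoderOutput, Tuple[torch.FloatTensor]]` signatures codify the library-wide `return_dict` convention. A hedged sketch using a randomly initialized `AutoencoderTiny` (any model following the same convention would do; outputs are illustrative):

import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny()                       # default config, random weights
z = torch.randn(1, 4, 32, 32)                 # latent batch
out = vae.decode(z)                           # output dataclass; use out.sample
sample = vae.decode(z, return_dict=False)[0]  # plain tuple; element 0 is the sample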
...
@@ -76,7 +76,7 @@ class ControlNetConditioningEmbedding(nn.Module):
         self,
         conditioning_embedding_channels: int,
         conditioning_channels: int = 3,
-        block_out_channels: Tuple[int] = (16, 32, 96, 256),
+        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
     ):
         super().__init__()
@@ -171,6 +171,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
             The tuple of output channel for each block in the `conditioning_embedding` layer.
         global_pool_conditions (`bool`, defaults to `False`):
+            TODO(Patrick) - unused parameter.
+        addition_embed_type_num_heads (`int`, defaults to 64):
+            The number of heads to use for the `TextTimeEmbedding` layer.
     """

     _supports_gradient_checkpointing = True
@@ -182,14 +185,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         conditioning_channels: int = 3,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
-        down_block_types: Tuple[str] = (
+        down_block_types: Tuple[str, ...] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
         only_cross_attention: Union[bool, Tuple[bool]] = False,
-        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
@@ -197,11 +200,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
-        attention_head_dim: Union[int, Tuple[int]] = 8,
-        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
@@ -211,9 +214,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         resnet_time_scale_shift: str = "default",
         projection_class_embeddings_input_dim: Optional[int] = None,
         controlnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
         global_pool_conditions: bool = False,
-        addition_embed_type_num_heads=64,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
@@ -426,7 +429,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         cls,
         unet: UNet2DConditionModel,
         controlnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
         load_weights_from_unet: bool = True,
     ):
         r"""
@@ -570,7 +573,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         self.set_attn_processor(processor, _remove_lora=True)

     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
         r"""
         Enable sliced attention computation.
@@ -635,7 +638,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
             module.gradient_checkpointing = value
@@ -653,7 +656,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
-    ) -> Union[ControlNetOutput, Tuple]:
+    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
         """
         The [`ControlNetModel`] forward method.
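
With `return_dict=False`, the annotated forward return is a pair: a tuple of down-block residual samples plus the mid-block residual sample. A hedged sketch with a randomly initialized model (shapes are illustrative; the conditioning image is 8x the latent resolution):

import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel()      # default config, random weights
sample = torch.randn(1, 4, 64, 64)  # noisy latents
text = torch.randn(1, 77, 1280)     # encoder hidden states
cond = torch.randn(1, 3, 512, 512)  # conditioning image
down_res, mid_res = controlnet(sample, 10, text, cond, return_dict=False)
# down_res: Tuple[torch.FloatTensor, ...], mid_res: torch.FloatTensor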
...
@@ -46,10 +46,10 @@ class FlaxControlNetOutput(BaseOutput):
 class FlaxControlNetConditioningEmbedding(nn.Module):
     conditioning_embedding_channels: int
-    block_out_channels: Tuple[int] = (16, 32, 96, 256)
+    block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
     dtype: jnp.dtype = jnp.float32

-    def setup(self):
+    def setup(self) -> None:
         self.conv_in = nn.Conv(
             self.block_out_channels[0],
             kernel_size=(3, 3),
@@ -87,7 +87,7 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
             dtype=self.dtype,
         )

-    def __call__(self, conditioning):
+    def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
         embedding = self.conv_in(conditioning)
         embedding = nn.silu(embedding)
@@ -148,17 +148,17 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
     """

     sample_size: int = 32
     in_channels: int = 4
-    down_block_types: Tuple[str] = (
+    down_block_types: Tuple[str, ...] = (
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "DownBlock2D",
     )
-    only_cross_attention: Union[bool, Tuple[bool]] = False
-    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    only_cross_attention: Union[bool, Tuple[bool, ...]] = False
+    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: Union[int, Tuple[int]] = 8
-    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+    attention_head_dim: Union[int, Tuple[int, ...]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
@@ -166,7 +166,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
     controlnet_conditioning_channel_order: str = "rgb"
-    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
+    conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)

     def init_weights(self, rng: jax.Array) -> FrozenDict:
         # init input tensors
@@ -182,7 +182,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
         return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

-    def setup(self):
+    def setup(self) -> None:
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
@@ -312,21 +312,21 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
     def __call__(
         self,
-        sample,
-        timesteps,
-        encoder_hidden_states,
-        controlnet_cond,
+        sample: jnp.ndarray,
+        timesteps: Union[jnp.ndarray, float, int],
+        encoder_hidden_states: jnp.ndarray,
+        controlnet_cond: jnp.ndarray,
         conditioning_scale: float = 1.0,
         return_dict: bool = True,
         train: bool = False,
-    ) -> Union[FlaxControlNetOutput, Tuple]:
+    ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
         r"""
         Args:
             sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
             timestep (`jnp.ndarray` or `float` or `int`): timesteps
             encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
             controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
-            conditioning_scale: (`float`) the scale factor for controlnet outputs
+            conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                 plain tuple.
@@ -335,8 +335,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
         Returns:
             [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
-                [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
-                When returning a tuple, the first element is the sample tensor.
+                [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
+                `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         channel_order = self.controlnet_conditioning_channel_order
         if channel_order == "bgr":
...
@@ -18,13 +18,14 @@ import inspect
 import itertools
 import os
 import re
+from collections import OrderedDict
 from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union

 import safetensors
 import torch
 from huggingface_hub import create_repo
-from torch import Tensor, device, nn
+from torch import Tensor, nn

 from .. import __version__
 from ..utils import (
@@ -61,7 +62,7 @@ if is_accelerate_available():
     from accelerate.utils.versions import is_torch_version

-def get_parameter_device(parameter: torch.nn.Module):
+def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
     try:
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
@@ -77,7 +78,7 @@ def get_parameter_device(parameter: torch.nn.Module):
         return first_tuple[1].device

-def get_parameter_dtype(parameter: torch.nn.Module):
+def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     try:
         params = tuple(parameter.parameters())
         if len(params) > 0:
@@ -130,7 +131,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     )

-def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+def load_model_dict_into_meta(
+    model,
+    state_dict: OrderedDict,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    model_name_or_path: Optional[str] = None,
+) -> List[str]:
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
@@ -156,7 +163,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
     return unexpected_keys

-def _load_state_dict_into_model(model_to_load, state_dict):
+def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
     state_dict = state_dict.copy()
@@ -164,7 +171,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix=""):
+    def load(module: torch.nn.Module, prefix: str = ""):
         args = (state_dict, prefix, {}, True, [], [], error_msgs)
         module._load_from_state_dict(*args)
@@ -220,7 +227,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

-    def enable_gradient_checkpointing(self):
+    def enable_gradient_checkpointing(self) -> None:
         """
         Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
@@ -229,7 +236,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))

-    def disable_gradient_checkpointing(self):
+    def disable_gradient_checkpointing(self) -> None:
         """
         Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
         *checkpoint activations* in other frameworks).
@@ -254,7 +261,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_mem_eff(module)

-    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
         r"""
         Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
@@ -290,7 +297,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(True, attention_op)

-    def disable_xformers_memory_efficient_attention(self):
+    def disable_xformers_memory_efficient_attention(self) -> None:
         r"""
         Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
         """
@@ -447,7 +454,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        save_function: Callable = None,
+        save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
         push_to_hub: bool = False,
@@ -910,10 +917,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     def _load_pretrained_model(
         cls,
         model,
-        state_dict,
+        state_dict: OrderedDict,
         resolved_archive_file,
-        pretrained_model_name_or_path,
-        ignore_mismatched_sizes=False,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        ignore_mismatched_sizes: bool = False,
     ):
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
@@ -1011,7 +1018,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

     @property
-    def device(self) -> device:
+    def device(self) -> torch.device:
         """
         `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         device).
@@ -1063,7 +1070,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

-    def _convert_deprecated_attention_blocks(self, state_dict):
+    def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []

         def recursive_find_attn_block(name, module):
@@ -1107,7 +1114,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")

-    def _temp_convert_self_to_deprecated_attention_blocks(self):
+    def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
         deprecated_attention_block_modules = []

         def recursive_find_attn_block(module):
@@ -1134,10 +1141,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             del module.to_v
             del module.to_out

-    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
         deprecated_attention_block_modules = []

-        def recursive_find_attn_block(module):
+        def recursive_find_attn_block(module) -> None:
             if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                 deprecated_attention_block_modules.append(module)
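
The new `-> torch.device` / `-> torch.dtype` annotations line up with the `ModelMixin` properties those helpers back. A tiny check, using any `ModelMixin` subclass (here `AutoencoderTiny`, randomly initialized):

import torch
from diffusers import AutoencoderTiny

model = AutoencoderTiny()
assert isinstance(model.device, torch.device)  # backed by get_parameter_device
assert model.dtype == torch.float32            # backed by get_parameter_dtype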
...
@@ -101,8 +101,8 @@ class AdaLayerNormSingle(nn.Module):
     def forward(
         self,
         timestep: torch.Tensor,
-        added_cond_kwargs: Dict[str, torch.Tensor] = None,
-        batch_size: int = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
...
...@@ -164,7 +164,9 @@ class Upsample2D(nn.Module): ...@@ -164,7 +164,9 @@ class Upsample2D(nn.Module):
else: else:
self.Conv2d_0 = conv self.Conv2d_0 = conv
def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose: if self.use_conv_transpose:
...@@ -256,7 +258,7 @@ class Downsample2D(nn.Module): ...@@ -256,7 +258,7 @@ class Downsample2D(nn.Module):
else: else:
self.conv = conv self.conv = conv
def forward(self, hidden_states, scale: float = 1.0): def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0: if self.use_conv and self.padding == 0:
...@@ -280,7 +282,7 @@ class FirUpsample2D(nn.Module): ...@@ -280,7 +282,7 @@ class FirUpsample2D(nn.Module):
"""A 2D FIR upsampling layer with an optional convolution. """A 2D FIR upsampling layer with an optional convolution.
Parameters: Parameters:
channels (`int`): channels (`int`, optional):
number of channels in the inputs and outputs. number of channels in the inputs and outputs.
use_conv (`bool`, default `False`): use_conv (`bool`, default `False`):
option to use a convolution. option to use a convolution.
...@@ -292,7 +294,7 @@ class FirUpsample2D(nn.Module): ...@@ -292,7 +294,7 @@ class FirUpsample2D(nn.Module):
def __init__( def __init__(
self, self,
channels: int = None, channels: Optional[int] = None,
out_channels: Optional[int] = None, out_channels: Optional[int] = None,
use_conv: bool = False, use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
...@@ -307,12 +309,12 @@ class FirUpsample2D(nn.Module): ...@@ -307,12 +309,12 @@ class FirUpsample2D(nn.Module):
def _upsample_2d( def _upsample_2d(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
weight: Optional[torch.Tensor] = None, weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None, kernel: Optional[torch.FloatTensor] = None,
factor: int = 2, factor: int = 2,
gain: float = 1, gain: float = 1,
) -> torch.Tensor: ) -> torch.FloatTensor:
"""Fused `upsample_2d()` followed by `Conv2d()`. """Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
...@@ -320,17 +322,21 @@ class FirUpsample2D(nn.Module): ...@@ -320,17 +322,21 @@ class FirUpsample2D(nn.Module):
arbitrary order. arbitrary order.
Args: Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. hidden_states (`torch.FloatTensor`):
weight: Weight tensor of the shape `[filterH, filterW, inChannels, Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. weight (`torch.FloatTensor`, *optional*):
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. performed by `inChannels = x.shape[0] // numGroups`.
factor: Integer upsampling factor (default: 2). kernel (`torch.FloatTensor`, *optional*):
gain: Scaling factor for signal magnitude (default: 1.0). FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*): Integer upsampling factor (default: 2).
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
Returns: Returns:
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same output (`torch.FloatTensor`):
datatype as `hidden_states`. Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
...@@ -392,7 +398,7 @@ class FirUpsample2D(nn.Module): ...@@ -392,7 +398,7 @@ class FirUpsample2D(nn.Module):
return output return output
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv: if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
...@@ -418,7 +424,7 @@ class FirDownsample2D(nn.Module): ...@@ -418,7 +424,7 @@ class FirDownsample2D(nn.Module):
def __init__( def __init__(
self, self,
channels: int = None, channels: Optional[int] = None,
out_channels: Optional[int] = None, out_channels: Optional[int] = None,
use_conv: bool = False, use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
...@@ -433,30 +439,35 @@ class FirDownsample2D(nn.Module): ...@@ -433,30 +439,35 @@ class FirDownsample2D(nn.Module):
def _downsample_2d( def _downsample_2d(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
weight: Optional[torch.Tensor] = None, weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None, kernel: Optional[torch.FloatTensor] = None,
factor: int = 2, factor: int = 2,
gain: float = 1, gain: float = 1,
) -> torch.Tensor: ) -> torch.FloatTensor:
"""Fused `Conv2d()` followed by `downsample_2d()`. """Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order. arbitrary order.
Args: Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. hidden_states (`torch.FloatTensor`):
weight: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`. performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * kernel (`torch.FloatTensor`, *optional*):
factor`, which corresponds to average pooling. FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
factor: Integer downsampling factor (default: 2). corresponds to average pooling.
gain: Scaling factor for signal magnitude (default: 1.0). factor (`int`, *optional*, defaults to `2`):
Integer downsampling factor.
gain (`float`, *optional*, defaults to `1.0`):
Scaling factor for signal magnitude.
Returns: Returns:
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and output (`torch.FloatTensor`):
same datatype as `hidden_states`. Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `hidden_states`.
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
...@@ -492,7 +503,7 @@ class FirDownsample2D(nn.Module): ...@@ -492,7 +503,7 @@ class FirDownsample2D(nn.Module):
return output return output
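To make the `factor`/`gain` semantics above concrete: with the default `kernel=None` (i.e. `[1] * factor`), the normalized FIR filter reduces to average pooling. A minimal sketch of that default case using a plain depthwise strided convolution; the names below are illustrative, not the module's internals:

```python
import torch
import torch.nn.functional as F

factor, gain = 2, 1.0
kernel_2d = torch.ones(factor, factor)
kernel_2d = kernel_2d / kernel_2d.sum() * gain  # normalize, then apply gain

x = torch.randn(1, 4, 8, 8)  # [N, C, H, W]
weight = kernel_2d.expand(x.shape[1], 1, factor, factor).contiguous()  # depthwise filter
out = F.conv2d(x, weight, stride=factor, groups=x.shape[1])
print(out.shape)  # torch.Size([1, 4, 4, 4]), i.e. [N, C, H // factor, W // factor]
```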
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv: if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
...@@ -682,7 +693,9 @@ class ResnetBlock2D(nn.Module): ...@@ -682,7 +693,9 @@ class ResnetBlock2D(nn.Module):
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
) )
def forward(self, input_tensor, temb, scale: float = 1.0): def forward(
self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
) -> torch.FloatTensor:
hidden_states = input_tensor hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
...@@ -778,7 +791,7 @@ class Conv1dBlock(nn.Module): ...@@ -778,7 +791,7 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels. out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel. kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, defaults to `8`): Number of groups to separate the channels into. n_groups (`int`, defaults to `8`): Number of groups to separate the channels into.
activation (`str`, defaults `mish`): Name of the activation function. activation (`str`, defaults to `mish`): Name of the activation function.
""" """
def __init__( def __init__(
...@@ -853,8 +866,8 @@ class ResidualTemporalBlock1D(nn.Module): ...@@ -853,8 +866,8 @@ class ResidualTemporalBlock1D(nn.Module):
def upsample_2d( def upsample_2d(
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor: ) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter. r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
...@@ -862,14 +875,19 @@ def upsample_2d( ...@@ -862,14 +875,19 @@ def upsample_2d(
a multiple of the upsampling factor. a multiple of the upsampling factor.
Args: Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. hidden_states (`torch.FloatTensor`):
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. kernel (`torch.FloatTensor`, *optional*):
factor: Integer upsampling factor (default: 2). FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
gain: Scaling factor for signal magnitude (default: 1.0). corresponds to nearest-neighbor upsampling.
factor (`int`, *optional*, defaults to `2`):
Integer upsampling factor.
gain (`float`, *optional*, defaults to `1.0`):
Scaling factor for signal magnitude.
Returns: Returns:
output: Tensor of the shape `[N, C, H * factor, W * factor]` output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H * factor, W * factor]`
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
if kernel is None: if kernel is None:
...@@ -892,8 +910,8 @@ def upsample_2d( ...@@ -892,8 +910,8 @@ def upsample_2d(
def downsample_2d( def downsample_2d(
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1 hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
) -> torch.Tensor: ) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter. r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
...@@ -901,14 +919,19 @@ def downsample_2d( ...@@ -901,14 +919,19 @@ def downsample_2d(
shape is a multiple of the downsampling factor. shape is a multiple of the downsampling factor.
Args: Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. hidden_states (`torch.FloatTensor`):
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
(separable). The default is `[1] * factor`, which corresponds to average pooling. kernel (`torch.FloatTensor`, *optional*):
factor: Integer downsampling factor (default: 2). FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
gain: Scaling factor for signal magnitude (default: 1.0). corresponds to average pooling.
factor (`int`, *optional*, defaults to `2`):
Integer downsampling factor.
gain (`float`, *optional*, defaults to `1.0`):
Scaling factor for signal magnitude.
Returns: Returns:
output: Tensor of the shape `[N, C, H // factor, W // factor]` output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]`
""" """
assert isinstance(factor, int) and factor >= 1 assert isinstance(factor, int) and factor >= 1
......
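A quick usage sketch for the two standalone helpers, assuming they remain importable from `diffusers.models.resnet` as in this diff; with the default kernels, `upsample_2d` behaves like nearest-neighbor upsampling and `downsample_2d` like average pooling:

```python
import torch
from diffusers.models.resnet import downsample_2d, upsample_2d

x = torch.randn(1, 4, 16, 16)  # [N, C, H, W]
up = upsample_2d(x, factor=2)      # -> [1, 4, 32, 32]
down = downsample_2d(x, factor=2)  # -> [1, 4, 8, 8]
print(up.shape, down.shape)
```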
...@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample_size: int = 32 sample_size: int = 32
in_channels: int = 4 in_channels: int = 4
out_channels: int = 4 out_channels: int = 4
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
) )
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280) block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2 layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8 attention_head_dim: Union[int, Tuple[int, ...]] = 8
num_attention_heads: Optional[Union[int, Tuple[int]]] = None num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
cross_attention_dim: int = 1280 cross_attention_dim: int = 1280
dropout: float = 0.0 dropout: float = 0.0
use_linear_projection: bool = False use_linear_projection: bool = False
...@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
freq_shift: int = 0 freq_shift: int = 0
use_memory_efficient_attention: bool = False use_memory_efficient_attention: bool = False
split_head_dim: bool = False split_head_dim: bool = False
transformer_layers_per_block: Union[int, Tuple[int]] = 1 transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
addition_embed_type: Optional[str] = None addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64 addition_embed_type_num_heads: int = 64
...@@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
} }
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self): def setup(self) -> None:
block_out_channels = self.block_out_channels block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4 time_embed_dim = block_out_channels[0] * 4
...@@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
def __call__( def __call__(
self, self,
sample, sample: jnp.ndarray,
timesteps, timesteps: Union[jnp.ndarray, float, int],
encoder_hidden_states, encoder_hidden_states: jnp.ndarray,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None, added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
down_block_additional_residuals=None, down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
mid_block_additional_residual=None, mid_block_additional_residual: Optional[jnp.ndarray] = None,
return_dict: bool = True, return_dict: bool = True,
train: bool = False, train: bool = False,
) -> Union[FlaxUNet2DConditionOutput, Tuple]: ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
r""" r"""
Args: Args:
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
......
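The tightened `__call__` annotations match the usual Flax calling pattern. A hedged sketch, assuming flax is installed; shapes are illustrative, and the `(batch, channel, height, width)` layout follows the docstring above:

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(sample_size=32)     # default SD-style config
params = unet.init_weights(jax.random.PRNGKey(0))   # rng -> FrozenDict of params
sample = jnp.zeros((1, 4, 32, 32))                  # (batch, channel, height, width)
timesteps = jnp.array([10], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 1280))    # cross_attention_dim = 1280
out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)                             # FlaxUNet2DConditionOutput.sample
```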
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -26,26 +26,26 @@ from .transformer_temporal import TransformerTemporalModel ...@@ -26,26 +26,26 @@ from .transformer_temporal import TransformerTemporalModel
def get_down_block( def get_down_block(
down_block_type, down_block_type: str,
num_layers, num_layers: int,
in_channels, in_channels: int,
out_channels, out_channels: int,
temb_channels, temb_channels: int,
add_downsample, add_downsample: bool,
resnet_eps, resnet_eps: float,
resnet_act_fn, resnet_act_fn: str,
num_attention_heads, num_attention_heads: int,
resnet_groups=None, resnet_groups: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
downsample_padding=None, downsample_padding: Optional[int] = None,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=True, use_linear_projection: bool = True,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
resnet_time_scale_shift="default", resnet_time_scale_shift: str = "default",
temporal_num_attention_heads=8, temporal_num_attention_heads: int = 8,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ) -> Union["DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion"]:
if down_block_type == "DownBlock3D": if down_block_type == "DownBlock3D":
return DownBlock3D( return DownBlock3D(
num_layers=num_layers, num_layers=num_layers,
...@@ -123,28 +123,28 @@ def get_down_block( ...@@ -123,28 +123,28 @@ def get_down_block(
def get_up_block( def get_up_block(
up_block_type, up_block_type: str,
num_layers, num_layers: int,
in_channels, in_channels: int,
out_channels, out_channels: int,
prev_output_channel, prev_output_channel: int,
temb_channels, temb_channels: int,
add_upsample, add_upsample: bool,
resnet_eps, resnet_eps: float,
resnet_act_fn, resnet_act_fn: str,
num_attention_heads, num_attention_heads: int,
resolution_idx=None, resolution_idx: Optional[int] = None,
resnet_groups=None, resnet_groups: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=True, use_linear_projection: bool = True,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
resnet_time_scale_shift="default", resnet_time_scale_shift: str = "default",
temporal_num_attention_heads=8, temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim=None, temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ) -> Union["UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion"]:
if up_block_type == "UpBlock3D": if up_block_type == "UpBlock3D":
return UpBlock3D( return UpBlock3D(
num_layers=num_layers, num_layers=num_layers,
...@@ -236,12 +236,12 @@ class UNetMidBlock3DCrossAttn(nn.Module): ...@@ -236,12 +236,12 @@ class UNetMidBlock3DCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=True, use_linear_projection: bool = True,
upcast_attention=False, upcast_attention: bool = False,
): ):
super().__init__() super().__init__()
...@@ -328,13 +328,13 @@ class UNetMidBlock3DCrossAttn(nn.Module): ...@@ -328,13 +328,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
temb=None, temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1, num_frames: int = 1,
cross_attention_kwargs=None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
): ) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
for attn, temp_attn, resnet, temp_conv in zip( for attn, temp_attn, resnet, temp_conv in zip(
...@@ -368,15 +368,15 @@ class CrossAttnDownBlock3D(nn.Module): ...@@ -368,15 +368,15 @@ class CrossAttnDownBlock3D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
downsample_padding=1, downsample_padding: int = 1,
add_downsample=True, add_downsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -454,13 +454,13 @@ class CrossAttnDownBlock3D(nn.Module): ...@@ -454,13 +454,13 @@ class CrossAttnDownBlock3D(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
temb=None, temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1, num_frames: int = 1,
cross_attention_kwargs=None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
): ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
# TODO(Patrick, William) - attention mask is not used # TODO(Patrick, William) - attention mask is not used
output_states = () output_states = ()
...@@ -503,9 +503,9 @@ class DownBlock3D(nn.Module): ...@@ -503,9 +503,9 @@ class DownBlock3D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -552,7 +552,9 @@ class DownBlock3D(nn.Module): ...@@ -552,7 +552,9 @@ class DownBlock3D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, num_frames=1): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
for resnet, temp_conv in zip(self.resnets, self.temp_convs): for resnet, temp_conv in zip(self.resnets, self.temp_convs):
...@@ -584,15 +586,15 @@ class CrossAttnUpBlock3D(nn.Module): ...@@ -584,15 +586,15 @@ class CrossAttnUpBlock3D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
resolution_idx=None, resolution_idx: Optional[int] = None,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -667,15 +669,15 @@ class CrossAttnUpBlock3D(nn.Module): ...@@ -667,15 +669,15 @@ class CrossAttnUpBlock3D(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
res_hidden_states_tuple, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb=None, temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
upsample_size=None, upsample_size: Optional[int] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1, num_frames: int = 1,
cross_attention_kwargs=None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
): ) -> torch.FloatTensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -738,9 +740,9 @@ class UpBlock3D(nn.Module): ...@@ -738,9 +740,9 @@ class UpBlock3D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
resolution_idx=None, resolution_idx: Optional[int] = None,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -784,7 +786,14 @@ class UpBlock3D(nn.Module): ...@@ -784,7 +786,14 @@ class UpBlock3D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
num_frames: int = 1,
) -> torch.FloatTensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -833,12 +842,12 @@ class DownBlockMotion(nn.Module): ...@@ -833,12 +842,12 @@ class DownBlockMotion(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
temporal_num_attention_heads=1, temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim=None, temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -890,7 +899,13 @@ class DownBlockMotion(nn.Module): ...@@ -890,7 +899,13 @@ class DownBlockMotion(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1): def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = 1,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
blocks = zip(self.resnets, self.motion_modules) blocks = zip(self.resnets, self.motion_modules)
...@@ -944,19 +959,19 @@ class CrossAttnDownBlockMotion(nn.Module): ...@@ -944,19 +959,19 @@ class CrossAttnDownBlockMotion(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
downsample_padding=1, downsample_padding: int = 1,
add_downsample=True, add_downsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
temporal_cross_attention_dim=None, temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads=8, temporal_num_attention_heads: int = 8,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1043,14 +1058,14 @@ class CrossAttnDownBlockMotion(nn.Module): ...@@ -1043,14 +1058,14 @@ class CrossAttnDownBlockMotion(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
temb=None, temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1, num_frames: int = 1,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs=None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
additional_residuals=None, additional_residuals: Optional[torch.FloatTensor] = None,
): ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
...@@ -1121,7 +1136,7 @@ class CrossAttnUpBlockMotion(nn.Module): ...@@ -1121,7 +1136,7 @@ class CrossAttnUpBlockMotion(nn.Module):
out_channels: int, out_channels: int,
prev_output_channel: int, prev_output_channel: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
transformer_layers_per_block: int = 1, transformer_layers_per_block: int = 1,
...@@ -1130,18 +1145,18 @@ class CrossAttnUpBlockMotion(nn.Module): ...@@ -1130,18 +1145,18 @@ class CrossAttnUpBlockMotion(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
temporal_cross_attention_dim=None, temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads=8, temporal_num_attention_heads: int = 8,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1232,8 +1247,8 @@ class CrossAttnUpBlockMotion(nn.Module): ...@@ -1232,8 +1247,8 @@ class CrossAttnUpBlockMotion(nn.Module):
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1, num_frames: int = 1,
): ) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
...@@ -1317,7 +1332,7 @@ class UpBlockMotion(nn.Module): ...@@ -1317,7 +1332,7 @@ class UpBlockMotion(nn.Module):
prev_output_channel: int, prev_output_channel: int,
out_channels: int, out_channels: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -1325,12 +1340,12 @@ class UpBlockMotion(nn.Module): ...@@ -1325,12 +1340,12 @@ class UpBlockMotion(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
temporal_norm_num_groups=32, temporal_norm_num_groups: int = 32,
temporal_cross_attention_dim=None, temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads=8, temporal_num_attention_heads: int = 8,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1381,8 +1396,14 @@ class UpBlockMotion(nn.Module): ...@@ -1381,8 +1396,14 @@ class UpBlockMotion(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward( def forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1 self,
): hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
num_frames: int = 1,
) -> torch.FloatTensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -1457,16 +1478,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module): ...@@ -1457,16 +1478,16 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
temporal_num_attention_heads=1, temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim=None, temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length=32, temporal_max_seq_length: int = 32,
): ):
super().__init__() super().__init__()
...@@ -1560,7 +1581,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module): ...@@ -1560,7 +1581,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1, num_frames: int = 1,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
......
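With the block factories now fully typed, a call site is self-describing. A sketch, assuming the module path `diffusers.models.unet_3d_blocks` and using only parameters visible in the signature above (the values are illustrative):

```python
from diffusers.models.unet_3d_blocks import get_down_block

block = get_down_block(
    down_block_type="DownBlock3D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="swish",
    num_attention_heads=8,  # required by the signature, unused by plain DownBlock3D
)
print(type(block).__name__)  # DownBlock3D
```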
...@@ -98,14 +98,19 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -98,14 +98,19 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample_size: Optional[int] = None, sample_size: Optional[int] = None,
in_channels: int = 4, in_channels: int = 4,
out_channels: int = 4, out_channels: int = 4,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock3D", "CrossAttnDownBlock3D",
"CrossAttnDownBlock3D", "CrossAttnDownBlock3D",
"CrossAttnDownBlock3D", "CrossAttnDownBlock3D",
"DownBlock3D", "DownBlock3D",
), ),
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), up_block_types: Tuple[str, ...] = (
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), "UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2, layers_per_block: int = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
...@@ -302,7 +307,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -302,7 +307,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size): def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
...@@ -404,7 +409,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -404,7 +409,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
def enable_forward_chunking(self, chunk_size=None, dim=0): def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
""" """
Sets the attention processor to use [feed forward Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
...@@ -460,7 +465,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -460,7 +465,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.set_attn_processor(processor, _remove_lora=True) self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value module.gradient_checkpointing = value
...@@ -510,7 +515,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -510,7 +515,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None, mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]: ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
r""" r"""
The [`UNet3DConditionModel`] forward method. The [`UNet3DConditionModel`] forward method.
......
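A usage sketch for the newly annotated memory helpers; the small config below is only an assumption chosen so the module constructs quickly, not a recommended setting:

```python
from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64, 64, 64),
    cross_attention_dim=64,
    attention_head_dim=8,
)
unet.set_attention_slice("auto")            # slice_size: Union[str, int, List[int]]
unet.enable_forward_chunking(chunk_size=2)  # chunk feed-forward layers along dim 0
unet.disable_forward_chunking()             # all three return None, per the annotations
```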
...@@ -50,14 +50,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -50,14 +50,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MotionModules(nn.Module): class MotionModules(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
layers_per_block=2, layers_per_block: int = 2,
num_attention_heads=8, num_attention_heads: int = 8,
attention_bias=False, attention_bias: bool = False,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
activation_fn="geglu", activation_fn: str = "geglu",
norm_num_groups=32, norm_num_groups: int = 32,
max_seq_length=32, max_seq_length: int = 32,
): ):
super().__init__() super().__init__()
self.motion_modules = nn.ModuleList([]) self.motion_modules = nn.ModuleList([])
...@@ -82,13 +82,13 @@ class MotionAdapter(ModelMixin, ConfigMixin): ...@@ -82,13 +82,13 @@ class MotionAdapter(ModelMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
block_out_channels=(320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
motion_layers_per_block=2, motion_layers_per_block: int = 2,
motion_mid_block_layers_per_block=1, motion_mid_block_layers_per_block: int = 1,
motion_num_attention_heads=8, motion_num_attention_heads: int = 8,
motion_norm_num_groups=32, motion_norm_num_groups: int = 32,
motion_max_seq_length=32, motion_max_seq_length: int = 32,
use_motion_mid_block=True, use_motion_mid_block: bool = True,
): ):
"""Container to store AnimateDiff Motion Modules """Container to store AnimateDiff Motion Modules
...@@ -182,29 +182,29 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -182,29 +182,29 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
sample_size: Optional[int] = None, sample_size: Optional[int] = None,
in_channels: int = 4, in_channels: int = 4,
out_channels: int = 4, out_channels: int = 4,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion",
"DownBlockMotion", "DownBlockMotion",
), ),
up_block_types: Tuple[str] = ( up_block_types: Tuple[str, ...] = (
"UpBlockMotion", "UpBlockMotion",
"CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion",
), ),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2, layers_per_block: int = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
act_fn: str = "silu", act_fn: str = "silu",
norm_num_groups: Optional[int] = 32, norm_num_groups: int = 32,
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
cross_attention_dim: int = 1280, cross_attention_dim: int = 1280,
use_linear_projection: bool = False, use_linear_projection: bool = False,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: Optional[int] = 32, motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8, motion_num_attention_heads: int = 8,
use_motion_mid_block: int = True, use_motion_mid_block: bool = True,
): ):
...@@ -448,7 +448,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -448,7 +448,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
return model return model
def freeze_unet2d_params(self): def freeze_unet2d_params(self) -> None:
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
unfrozen for fine tuning. unfrozen for fine tuning.
""" """
...@@ -472,9 +472,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -472,9 +472,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for param in motion_modules.parameters(): for param in motion_modules.parameters():
param.requires_grad = True param.requires_grad = True
return def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]):
for i, down_block in enumerate(motion_adapter.down_blocks): for i, down_block in enumerate(motion_adapter.down_blocks):
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
for i, up_block in enumerate(motion_adapter.up_blocks): for i, up_block in enumerate(motion_adapter.up_blocks):
...@@ -492,7 +490,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -492,7 +490,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
variant: Optional[str] = None, variant: Optional[str] = None,
push_to_hub: bool = False, push_to_hub: bool = False,
**kwargs, **kwargs,
): ) -> None:
state_dict = self.state_dict() state_dict = self.state_dict()
# Extract all motion modules # Extract all motion modules
...@@ -582,7 +580,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -582,7 +580,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size=None, dim=0): def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
""" """
Sets the attention processor to use [feed forward Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
...@@ -612,7 +610,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -612,7 +610,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_feed_forward(module, chunk_size, dim) fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self): def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"): if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
...@@ -624,7 +622,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -624,7 +622,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
fn_recursive_feed_forward(module, None, 0) fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self) -> None:
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
""" """
...@@ -639,12 +637,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -639,12 +637,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
self.set_attn_processor(processor, _remove_lora=True) self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2): def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied. The suffixes after the scaling factors represent the stage blocks where they are being applied.
...@@ -669,7 +667,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -669,7 +667,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
setattr(upsample_block, "b2", b2) setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self): def disable_freeu(self) -> None:
"""Disables the FreeU mechanism.""" """Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"} freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
...@@ -688,7 +686,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -688,7 +686,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None, mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]: ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
r""" r"""
The [`UNetMotionModel`] forward method. The [`UNetMotionModel`] forward method.
......
...@@ -148,7 +148,9 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -148,7 +148,9 @@ class VQModel(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec) return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def forward(
self, sample: torch.FloatTensor, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]:
r""" r"""
The [`VQModel`] forward method. The [`VQModel`] forward method.
......
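A sketch of the two return modes the widened annotation covers, using the default config and an illustrative input:

```python
import torch
from diffusers import VQModel

vq = VQModel()                     # default config
x = torch.randn(1, 3, 32, 32)
out = vq(x)                        # DecoderOutput with a .sample field
(dec,) = vq(x, return_dict=False)  # plain one-element tuple
print(out.sample.shape, dec.shape)
```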
...@@ -37,7 +37,7 @@ class SchedulerType(Enum): ...@@ -37,7 +37,7 @@ class SchedulerType(Enum):
PIECEWISE_CONSTANT = "piecewise_constant" PIECEWISE_CONSTANT = "piecewise_constant"
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
""" """
Create a schedule with a constant learning rate, using the learning rate set in the optimizer. Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
...@@ -53,7 +53,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): ...@@ -53,7 +53,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
""" """
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
increases linearly between 0 and the initial lr set in the optimizer. increases linearly between 0 and the initial lr set in the optimizer.
...@@ -78,7 +78,7 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in ...@@ -78,7 +78,7 @@ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: in
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
""" """
Create a schedule with a piecewise constant learning rate that changes according to `step_rules`, using the learning rate set in the optimizer. Create a schedule with a piecewise constant learning rate that changes according to `step_rules`, using the learning rate set in the optimizer.
...@@ -120,7 +120,9 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_ ...@@ -120,7 +120,9 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_
return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): def get_linear_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
""" """
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
...@@ -151,7 +153,7 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st ...@@ -151,7 +153,7 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
def get_cosine_schedule_with_warmup( def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
): ) -> LambdaLR:
""" """
Create a schedule with a learning rate that decreases following the values of the cosine function from the Create a schedule with a learning rate that decreases following the values of the cosine function from the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
...@@ -185,7 +187,7 @@ def get_cosine_schedule_with_warmup( ...@@ -185,7 +187,7 @@ def get_cosine_schedule_with_warmup(
def get_cosine_with_hard_restarts_schedule_with_warmup( def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
): ) -> LambdaLR:
""" """
Create a schedule with a learning rate that decreases following the values of the cosine function from the Create a schedule with a learning rate that decreases following the values of the cosine function from the
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
...@@ -219,8 +221,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup( ...@@ -219,8 +221,13 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(
def get_polynomial_decay_schedule_with_warmup( def get_polynomial_decay_schedule_with_warmup(
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 optimizer: Optimizer,
): num_warmup_steps: int,
num_training_steps: int,
lr_end: float = 1e-7,
power: float = 1.0,
last_epoch: int = -1,
) -> LambdaLR:
""" """
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
...@@ -288,7 +295,7 @@ def get_scheduler( ...@@ -288,7 +295,7 @@ def get_scheduler(
num_cycles: int = 1, num_cycles: int = 1,
power: float = 1.0, power: float = 1.0,
last_epoch: int = -1, last_epoch: int = -1,
): ) -> LambdaLR:
""" """
Unified API to get any scheduler from its name. Unified API to get any scheduler from its name.
......
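The `-> LambdaLR` annotations make the usage pattern explicit: every helper returns a standard PyTorch scheduler that is stepped alongside the optimizer. A minimal sketch against `diffusers.optimization` (the module this diff edits):

```python
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(3):    # training-loop excerpt
    optimizer.step()
    lr_scheduler.step()  # one scheduler step per optimizer step
print(lr_scheduler.get_last_lr())
```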
...@@ -28,7 +28,7 @@ from logging import ( ...@@ -28,7 +28,7 @@ from logging import (
WARN, # NOQA WARN, # NOQA
WARNING, # NOQA WARNING, # NOQA
) )
from typing import Optional from typing import Dict, Optional
from tqdm import auto as tqdm_lib from tqdm import auto as tqdm_lib
...@@ -49,7 +49,7 @@ _default_log_level = logging.WARNING ...@@ -49,7 +49,7 @@ _default_log_level = logging.WARNING
_tqdm_active = True _tqdm_active = True
def _get_default_logging_level(): def _get_default_logging_level() -> int:
""" """
If the DIFFUSERS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it is If the DIFFUSERS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it is
not, fall back to `_default_log_level`. not, fall back to `_default_log_level`.
...@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None: ...@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None:
_default_handler = None _default_handler = None
def get_log_levels_dict(): def get_log_levels_dict() -> Dict[str, int]:
return log_levels return log_levels
...@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None: ...@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None:
_get_library_root_logger().setLevel(verbosity) _get_library_root_logger().setLevel(verbosity)
def set_verbosity_info(): def set_verbosity_info() -> None:
"""Set the verbosity to the `INFO` level.""" """Set the verbosity to the `INFO` level."""
return set_verbosity(INFO) return set_verbosity(INFO)
def set_verbosity_warning(): def set_verbosity_warning() -> None:
"""Set the verbosity to the `WARNING` level.""" """Set the verbosity to the `WARNING` level."""
return set_verbosity(WARNING) return set_verbosity(WARNING)
def set_verbosity_debug(): def set_verbosity_debug() -> None:
"""Set the verbosity to the `DEBUG` level.""" """Set the verbosity to the `DEBUG` level."""
return set_verbosity(DEBUG) return set_verbosity(DEBUG)
def set_verbosity_error(): def set_verbosity_error() -> None:
"""Set the verbosity to the `ERROR` level.""" """Set the verbosity to the `ERROR` level."""
return set_verbosity(ERROR) return set_verbosity(ERROR)
...@@ -263,7 +263,7 @@ def reset_format() -> None: ...@@ -263,7 +263,7 @@ def reset_format() -> None:
handler.setFormatter(None) handler.setFormatter(None)
def warning_advice(self, *args, **kwargs): def warning_advice(self, *args, **kwargs) -> None:
""" """
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed. warning will not be printed.
...@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool: ...@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool:
return bool(_tqdm_active) return bool(_tqdm_active)
def enable_progress_bar(): def enable_progress_bar() -> None:
"""Enable tqdm progress bar.""" """Enable tqdm progress bar."""
global _tqdm_active global _tqdm_active
_tqdm_active = True _tqdm_active = True
def disable_progress_bar(): def disable_progress_bar() -> None:
"""Disable tqdm progress bar.""" """Disable tqdm progress bar."""
global _tqdm_active global _tqdm_active
_tqdm_active = False _tqdm_active = False
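A short usage sketch for the annotated verbosity helpers:

```python
from diffusers.utils import logging

logging.set_verbosity_info()                 # or set_verbosity_debug/_warning/_error
logger = logging.get_logger("diffusers")
logger.info("INFO-level messages are now visible")
print(logging.get_log_levels_dict().keys())  # Dict[str, int], per the new annotation
logging.disable_progress_bar()               # tqdm off; enable_progress_bar() restores it
```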
...@@ -24,7 +24,7 @@ import numpy as np ...@@ -24,7 +24,7 @@ import numpy as np
from .import_utils import is_torch_available from .import_utils import is_torch_available
def is_tensor(x): def is_tensor(x) -> bool:
""" """
Tests if `x` is a `torch.Tensor` or `np.ndarray`. Tests if `x` is a `torch.Tensor` or `np.ndarray`.
""" """
...@@ -66,7 +66,7 @@ class BaseOutput(OrderedDict): ...@@ -66,7 +66,7 @@ class BaseOutput(OrderedDict):
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
) )
def __post_init__(self): def __post_init__(self) -> None:
class_fields = fields(self) class_fields = fields(self)
# Safety and consistency checks # Safety and consistency checks
...@@ -97,14 +97,14 @@ class BaseOutput(OrderedDict): ...@@ -97,14 +97,14 @@ class BaseOutput(OrderedDict):
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __getitem__(self, k): def __getitem__(self, k: Any) -> Any:
if isinstance(k, str): if isinstance(k, str):
inner_dict = dict(self.items()) inner_dict = dict(self.items())
return inner_dict[k] return inner_dict[k]
else: else:
return self.to_tuple()[k] return self.to_tuple()[k]
def __setattr__(self, name, value): def __setattr__(self, name: Any, value: Any) -> None:
if name in self.keys() and value is not None: if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors # Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value) super().__setitem__(name, value)
...@@ -123,7 +123,7 @@ class BaseOutput(OrderedDict): ...@@ -123,7 +123,7 @@ class BaseOutput(OrderedDict):
args = tuple(getattr(self, field.name) for field in fields(self)) args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]: def to_tuple(self) -> Tuple[Any, ...]:
""" """
Convert self to a tuple containing all the attributes/keys that are not `None`. Convert self to a tuple containing all the attributes/keys that are not `None`.
""" """
......
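The `Tuple[Any, ...]` fix matches how `BaseOutput` actually behaves: a variable-length tuple of the non-`None` fields. A sketch with a hypothetical subclass (the class name and fields are illustrative only):

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np
from diffusers.utils import BaseOutput


@dataclass
class ExampleOutput(BaseOutput):  # hypothetical subclass, for illustration
    sample: np.ndarray
    extra: Optional[np.ndarray] = None


out = ExampleOutput(sample=np.zeros((1, 3)))
assert out["sample"] is out.sample is out[0]  # key, attribute, and index access agree
assert len(out.to_tuple()) == 1               # `None` fields are dropped from the tuple
```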
...@@ -82,14 +82,14 @@ def randn_tensor( ...@@ -82,14 +82,14 @@ def randn_tensor(
return latents return latents
def is_compiled_module(module): def is_compiled_module(module) -> bool:
"""Check whether the module was compiled with torch.compile()""" """Check whether the module was compiled with torch.compile()"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def fourier_filter(x_in, threshold, scale): def fourier_filter(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here: This version of the method comes from here:
......
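For context, a self-contained re-derivation of the FreeU-style filter the annotation describes: scale the centered low-frequency band (a square of half-width `threshold`) by `scale` and keep the rest. Note `scale` is annotated as `float` above, since FreeU's `s1`/`s2` factors are fractional; this sketch is not the library's exact code.

```python
import torch
import torch.fft as fft


def fourier_filter_sketch(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    # Move to the frequency domain, with low frequencies shifted to the center.
    x_freq = fft.fftshift(fft.fftn(x_in, dim=(-2, -1)), dim=(-2, -1))
    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x_in.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    # Undo the shift and transform back; the input was real, so keep .real.
    x_freq = fft.ifftshift(x_freq * mask, dim=(-2, -1))
    return fft.ifftn(x_freq, dim=(-2, -1)).real


filtered = fourier_filter_sketch(torch.randn(1, 4, 32, 32), threshold=1, scale=0.9)
print(filtered.shape)  # unchanged: [1, 4, 32, 32]
```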