Unverified Commit 1f0705ad authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Big refactor] move unets to `unets` module 🦋 (#6630)

* move unets to `unets` module 🦋

* parameterize unet-level import.

* fix flax unet2dcondition model import

* models __init__

* mildly deprecating models.unet_2d_blocks in favor of models.unets.unet_2d_blocks.

* noqa

* correct deprecation behaviour

* inherit from the actual classes.

* Empty-Commit

* backwards compatibility for unet_2d.py

* backward compatibility for unet_2d_condition

* bc for unet_1d

* bc for unet_1d_blocks
parent 5e96333c
...@@ -19,11 +19,11 @@ import torch ...@@ -19,11 +19,11 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging from ...utils import BaseOutput, logging
from .attention_processor import Attention, AttentionProcessor, AttnProcessor from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps from ..embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -17,19 +17,19 @@ import torch ...@@ -17,19 +17,19 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin from ...loaders import UNet2DConditionLoadersMixin
from ..utils import logging from ...utils import logging
from .attention_processor import ( from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS,
AttentionProcessor, AttentionProcessor,
AttnAddedKVProcessor, AttnAddedKVProcessor,
AttnProcessor, AttnProcessor,
) )
from .embeddings import TimestepEmbedding, Timesteps from ..embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel from ..transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import ( from .unet_3d_blocks import (
...@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
) )
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
""" """
Sets the attention processor to use [feed forward Sets the attention processor to use [feed forward
...@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for module in self.children(): for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim) fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self) -> None: def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"): if hasattr(module, "set_chunk_feed_forward"):
...@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for module in self.children(): for module in self.children():
fn_recursive_feed_forward(module, None, 0) fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self) -> None: def set_default_attn_processor(self) -> None:
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
...@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
...@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
setattr(upsample_block, "b1", b1) setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2) setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self) -> None: def disable_freeu(self) -> None:
"""Disables the FreeU mechanism.""" """Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"} freeu_keys = {"s1", "s2", "b1", "b2"}
......
...@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union ...@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin from ...loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging from ...utils import BaseOutput, logging
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps from ..embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
...@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL ...@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
if hasattr(module, "gradient_checkpointing"): if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value module.gradient_checkpointing = value
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
""" """
Sets the attention processor to use [feed forward Sets the attention processor to use [feed forward
......
...@@ -20,20 +20,20 @@ import torch.nn.functional as F ...@@ -20,20 +20,20 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from ..configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin from ...loaders import PeftAdapterMixin
from .attention import BasicTransformerBlock, SkipFFTransformerBlock from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from .attention_processor import ( from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS,
AttentionProcessor, AttentionProcessor,
AttnAddedKVProcessor, AttnAddedKVProcessor,
AttnProcessor, AttnProcessor,
) )
from .embeddings import TimestepEmbedding, get_timestep_embedding from ..embeddings import TimestepEmbedding, get_timestep_embedding
from .modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .normalization import GlobalResponseNorm, RMSNorm from ..normalization import GlobalResponseNorm, RMSNorm
from .resnet import Downsample2D, Upsample2D from ..resnet import Downsample2D, Upsample2D
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
...@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return logits return logits
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
......
...@@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor ...@@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
......
...@@ -36,8 +36,8 @@ from ...models.embeddings import ( ...@@ -36,8 +36,8 @@ from ...models.embeddings import (
from ...models.modeling_utils import ModelMixin from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unet_2d_condition import UNet2DConditionOutput from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging from ...utils import BaseOutput, is_torch_version, logging
...@@ -513,7 +513,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -513,7 +513,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
) )
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -537,7 +537,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -537,7 +537,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -572,7 +572,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -572,7 +572,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
...@@ -588,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -588,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
self.set_attn_processor(processor) self.set_attn_processor(processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size): def set_attention_slice(self, slice_size):
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
...@@ -654,7 +654,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -654,7 +654,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for module in self.children(): for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size) fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"): if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value module.gradient_checkpointing = value
...@@ -687,7 +687,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -687,7 +687,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens. which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple. tuple.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
...@@ -700,8 +700,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -700,8 +700,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
which adds large negative values to the attention scores corresponding to "discard" tokens. which adds large negative values to the attention scores corresponding to "discard" tokens.
Returns: Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor. a `tuple` is returned where the first element is the sample tensor.
""" """
# By default samples have to be AT least a multiple of the overall upsampling factor. # By default samples have to be AT least a multiple of the overall upsampling factor.
......
...@@ -33,7 +33,7 @@ from ....models.embeddings import ( ...@@ -33,7 +33,7 @@ from ....models.embeddings import (
) )
from ....models.resnet import ResnetBlockCondNorm2D from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformer_2d import Transformer2DModel from ....models.transformer_2d import Transformer2DModel
from ....models.unet_2d_condition import UNet2DConditionOutput from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu from ....utils.torch_utils import apply_freeu
...@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module): ...@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
return objs return objs
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin): class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r""" r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
...@@ -1095,7 +1096,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -1095,7 +1096,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens. which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple. tuple.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
...@@ -1111,8 +1112,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -1111,8 +1112,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns: Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor. a `tuple` is returned where the first element is the sample tensor.
""" """
# By default samples have to be AT least a multiple of the overall upsampling factor. # By default samples have to be AT least a multiple of the overall upsampling factor.
...@@ -1785,7 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1785,7 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
return hidden_states, output_states return hidden_states, output_states
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim # Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module): class UpBlockFlat(nn.Module):
def __init__( def __init__(
self, self,
...@@ -1896,7 +1897,7 @@ class UpBlockFlat(nn.Module): ...@@ -1896,7 +1897,7 @@ class UpBlockFlat(nn.Module):
return hidden_states return hidden_states
# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim # Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module): class CrossAttnUpBlockFlat(nn.Module):
def __init__( def __init__(
self, self,
...@@ -2070,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -2070,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
return hidden_states return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat # Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module): class UNetMidBlockFlat(nn.Module):
""" """
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
...@@ -2226,7 +2227,7 @@ class UNetMidBlockFlat(nn.Module): ...@@ -2226,7 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
return hidden_states return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat # Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module): class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__( def __init__(
self, self,
...@@ -2373,7 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ...@@ -2373,7 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
return hidden_states return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat # Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module): class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
def __init__( def __init__(
self, self,
......
...@@ -752,7 +752,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin): ...@@ -752,7 +752,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
cross_attention_kwargs (*optional*): cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used. Keyword arguments to supply to the cross attention layers, if used.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
......
...@@ -66,7 +66,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft ...@@ -66,7 +66,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
self.set_default_attn_processor() self.set_default_attn_processor()
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -90,7 +90,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft ...@@ -90,7 +90,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -125,7 +125,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft ...@@ -125,7 +125,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from diffusers.models.unet_2d_blocks import * # noqa F403 from diffusers.models.unets.unet_2d_blocks import * # noqa F403
from diffusers.utils.testing_utils import torch_device from diffusers.utils.testing_utils import torch_device
from .test_unet_blocks_common import UNetBlockTesterMixin from .test_unet_blocks_common import UNetBlockTesterMixin
......
...@@ -28,7 +28,7 @@ from diffusers import ( ...@@ -28,7 +28,7 @@ from diffusers import (
StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.models.unet_2d_blocks import UNetMidBlock2D from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment