"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6fd458e99d0b465bea6a8002aff5357514862751"
Unverified Commit 1f0705ad authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Big refactor] move unets to `unets` module 🦋 (#6630)

* move unets to  module 🦋

* parameterize unet-level import.

* fix flax unet2dcondition model import

* models __init__

* mildly depcrecating models.unet_2d_blocks in favor of models.unets.unet_2d_blocks.

* noqa

* correct depcrecation behaviour

* inherit from the actual classes.

* Empty-Commit

* backwards compatibility for unet_2d.py

* backward compatibility for unet_2d_condition

* bc for unet_1d

* bc for unet_1d_blocks
parent 5e96333c
...@@ -22,4 +22,4 @@ The abstract from the paper is: ...@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNetMotionModel [[autodoc]] UNetMotionModel
## UNet3DConditionOutput ## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput [[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
...@@ -22,4 +22,4 @@ The abstract from the paper is: ...@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet1DModel [[autodoc]] UNet1DModel
## UNet1DOutput ## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput [[autodoc]] models.unets.unet_1d.UNet1DOutput
...@@ -22,10 +22,10 @@ The abstract from the paper is: ...@@ -22,10 +22,10 @@ The abstract from the paper is:
[[autodoc]] UNet2DConditionModel [[autodoc]] UNet2DConditionModel
## UNet2DConditionOutput ## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput [[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
## FlaxUNet2DConditionModel ## FlaxUNet2DConditionModel
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel [[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
## FlaxUNet2DConditionOutput ## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput [[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
...@@ -22,4 +22,4 @@ The abstract from the paper is: ...@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet2DModel [[autodoc]] UNet2DModel
## UNet2DOutput ## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput [[autodoc]] models.unets.unet_2d.UNet2DOutput
...@@ -22,4 +22,4 @@ The abstract from the paper is: ...@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet3DConditionModel [[autodoc]] UNet3DConditionModel
## UNet3DConditionOutput ## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput [[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
...@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor ...@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import ( from diffusers.schedulers import (
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from diffusers import StableDiffusionControlNetPipeline from diffusers import StableDiffusionControlNetPipeline
from diffusers.models import ControlNetModel from diffusers.models import ControlNetModel
from diffusers.models.attention import BasicTransformerBlock from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging from diffusers.utils import logging
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from diffusers.models.attention import BasicTransformerBlock from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import PIL_INTERPOLATION, logging from diffusers.utils import PIL_INTERPOLATION, logging
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from diffusers import StableDiffusionXLPipeline from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention import BasicTransformerBlock from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import ( from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D, CrossAttnDownBlock2D,
CrossAttnUpBlock2D, CrossAttnUpBlock2D,
DownBlock2D, DownBlock2D,
......
...@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc ...@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc
from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.lora import LoRACompatibleConv from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_utils import ModelMixin from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import ( from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D, CrossAttnDownBlock2D,
CrossAttnUpBlock2D, CrossAttnUpBlock2D,
DownBlock2D, DownBlock2D,
...@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import ( ...@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import (
UpBlock2D, UpBlock2D,
Upsample2D, Upsample2D,
) )
from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging from diffusers.utils import BaseOutput, logging
......
...@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer ...@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel from diffusers import VQModel
from diffusers.models.attention_processor import AttnProcessor from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.uvit_2d import UVit2DModel from diffusers.models.unets.uvit_2d import UVit2DModel
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
from diffusers.schedulers import AmusedScheduler from diffusers.schedulers import AmusedScheduler
......
...@@ -14,7 +14,7 @@ from tqdm import tqdm ...@@ -14,7 +14,7 @@ from tqdm import tqdm
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
args = ArgumentParser() args = ArgumentParser()
......
...@@ -382,7 +382,7 @@ except OptionalDependencyNotAvailable: ...@@ -382,7 +382,7 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend( _import_structure["schedulers"].extend(
...@@ -711,7 +711,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -711,7 +711,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .models.controlnet_flax import FlaxControlNetModel from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline from .pipelines import FlaxDiffusionPipeline
from .schedulers import ( from .schedulers import (
......
...@@ -16,7 +16,7 @@ import numpy as np ...@@ -16,7 +16,7 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from ...models.unet_1d import UNet1DModel from ...models.unets.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline from ...pipelines import DiffusionPipeline
from ...utils.dummy_pt_objects import DDPMScheduler from ...utils.dummy_pt_objects import DDPMScheduler
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
......
...@@ -39,19 +39,19 @@ if is_torch_available(): ...@@ -39,19 +39,19 @@ if is_torch_available():
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"] _import_structure["transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["uvit_2d"] = ["UVit2DModel"] _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"] _import_structure["vq_model"] = ["VQModel"]
if is_flax_available(): if is_flax_available():
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
...@@ -73,19 +73,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -73,19 +73,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .t5_film_transformer import T5FilmDecoder from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel from .transformer_temporal import TransformerTemporalModel
from .unet_1d import UNet1DModel from .unets import (
from .unet_2d import UNet2DModel Kandinsky3UNet,
from .unet_2d_condition import UNet2DConditionModel MotionAdapter,
from .unet_3d_condition import UNet3DConditionModel UNet1DModel,
from .unet_kandinsky3 import Kandinsky3UNet UNet2DConditionModel,
from .unet_motion_model import MotionAdapter, UNetMotionModel UNet2DModel,
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel UNet3DConditionModel,
from .uvit_2d import UVit2DModel UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
)
from .vq_model import VQModel from .vq_model import VQModel
if is_flax_available(): if is_flax_available():
from .controlnet_flax import FlaxControlNetModel from .controlnet_flax import FlaxControlNetModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL from .vae_flax import FlaxAutoencoderKL
else: else:
......
...@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
self.use_slicing = False self.use_slicing = False
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
...@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=dec) return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self): def fuse_qkv_projections(self):
""" """
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
...@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
if isinstance(module, Attention): if isinstance(module, Attention):
module.fuse_projections(fuse=True) module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self): def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled. """Disables the fused QKV projection if enabled.
......
...@@ -23,7 +23,7 @@ from ...utils.accelerate_utils import apply_forward_hook ...@@ -23,7 +23,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
...@@ -242,7 +242,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin ...@@ -242,7 +242,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
module.gradient_checkpointing = value module.gradient_checkpointing = value
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -266,7 +266,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin ...@@ -266,7 +266,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
......
...@@ -31,7 +31,7 @@ from ..attention_processor import ( ...@@ -31,7 +31,7 @@ from ..attention_processor import (
AttnProcessor, AttnProcessor,
) )
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..unet_2d import UNet2DModel from ..unets.unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
...@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ...@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.use_slicing = False self.use_slicing = False
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ...@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ...@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
......
...@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version ...@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..activations import get_activation from ..activations import get_activation
from ..attention_processor import SpatialNorm from ..attention_processor import SpatialNorm
from ..unet_2d_blocks import ( from ..unets.unet_2d_blocks import (
AutoencoderTinyBlock, AutoencoderTinyBlock,
UNetMidBlock2D, UNetMidBlock2D,
get_down_block, get_down_block,
......
...@@ -30,8 +30,14 @@ from .attention_processor import ( ...@@ -30,8 +30,14 @@ from .attention_processor import (
) )
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block from .unets.unet_2d_blocks import (
from .unet_2d_condition import UNet2DConditionModel CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): ...@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return controlnet return controlnet
@property @property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
...@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): ...@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return processors return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Sets the attention processor to use to compute attention. Sets the attention processor to use to compute attention.
...@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): ...@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
for name, module in self.named_children(): for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor) fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self): def set_default_attn_processor(self):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
...@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): ...@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.set_attn_processor(processor) self.set_attn_processor(processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r""" r"""
Enable sliced attention computation. Enable sliced attention computation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment