Unverified Commit 09b7bfce authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] move transformer scripts to `transformers` modules (#6747)

* move transformer scripts to transformers modules

* move transformer model test

* move prior transformer test to  directory

* fix doc path

* correct doc path

* add: __init__.py
parent 5d8b1987
...@@ -33,7 +33,7 @@ from ..attention_processor import ( ...@@ -33,7 +33,7 @@ from ..attention_processor import (
) )
from ..embeddings import TimestepEmbedding, Timesteps from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import ( from .unet_3d_blocks import (
CrossAttnDownBlock3D, CrossAttnDownBlock3D,
CrossAttnUpBlock3D, CrossAttnUpBlock3D,
......
...@@ -29,7 +29,7 @@ from ..attention_processor import ( ...@@ -29,7 +29,7 @@ from ..attention_processor import (
) )
from ..embeddings import TimestepEmbedding, Timesteps from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import ( from .unet_3d_blocks import (
......
...@@ -35,7 +35,7 @@ from ...models.embeddings import ( ...@@ -35,7 +35,7 @@ from ...models.embeddings import (
) )
from ...models.modeling_utils import ModelMixin from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel from ...models.transformers.transformer_2d import Transformer2DModel
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging from ...utils import BaseOutput, is_torch_version, logging
......
...@@ -19,7 +19,6 @@ from ....models.attention_processor import ( ...@@ -19,7 +19,6 @@ from ....models.attention_processor import (
AttnAddedKVProcessor2_0, AttnAddedKVProcessor2_0,
AttnProcessor, AttnProcessor,
) )
from ....models.dual_transformer_2d import DualTransformer2DModel
from ....models.embeddings import ( from ....models.embeddings import (
GaussianFourierProjection, GaussianFourierProjection,
ImageHintTimeEmbedding, ImageHintTimeEmbedding,
...@@ -32,7 +31,8 @@ from ....models.embeddings import ( ...@@ -32,7 +31,8 @@ from ....models.embeddings import (
Timesteps, Timesteps,
) )
from ....models.resnet import ResnetBlockCondNorm2D from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformer_2d import Transformer2DModel from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
from ....models.transformers.transformer_2d import Transformer2DModel
from ....models.unets.unet_2d_condition import UNet2DConditionOutput from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu from ....utils.torch_utils import apply_freeu
......
...@@ -10,7 +10,7 @@ from ...models.attention import FeedForward ...@@ -10,7 +10,7 @@ from ...models.attention import FeedForward
from ...models.attention_processor import Attention from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ...models.normalization import AdaLayerNorm from ...models.normalization import AdaLayerNorm
from ...models.transformer_2d import Transformer2DModelOutput from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import logging from ...utils import logging
......
...@@ -24,7 +24,7 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU ...@@ -24,7 +24,7 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
backend_manual_seed, backend_manual_seed,
require_torch_accelerator_with_fp64, require_torch_accelerator_with_fp64,
......
...@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import ( ...@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
torch_device, torch_device,
) )
from .test_modeling_common import ModelTesterMixin from ..test_modeling_common import ModelTesterMixin
enable_full_determinism() enable_full_determinism()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment