Unverified Commit f108ad88 authored by Aryan, committed by GitHub

Update modeling imports (#11129)

update
parent e30d3bf5
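
The change is mechanical: inside modules of the `diffusers.models` package, imports spelled `from ...models.<name> import ...` are rewritten to the shorter `from ..<name> import ...`. Because the diffed modules sit two packages below `diffusers`, both spellings resolve to the same absolute module, so the rewrite has no runtime effect. A minimal sketch of that equivalence (assumes an installed `diffusers`; `diffusers.models.controlnets` is used as the anchor package purely for illustration):

```python
import importlib

# Resolve both relative spellings from this diff against the same anchor package
# (one of the subpackages the diffed modules live in).
old_style = importlib.import_module("...models.modeling_utils", package="diffusers.models.controlnets")
new_style = importlib.import_module("..modeling_utils", package="diffusers.models.controlnets")

# Both resolve to diffusers.models.modeling_utils, so the rewrite is a no-op at runtime.
assert old_style is new_style
print(new_style.__name__)  # diffusers.models.modeling_utils
```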
@@ -20,12 +20,12 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.attention_processor import AttentionProcessor
-from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
 from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
...
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ..modeling_utils import ModelMixin
 logger = logging.get_logger(__name__)
...
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
-from ...models.controlnets.controlnet import ControlNetOutput
-from ...models.controlnets.controlnet_union import ControlNetUnionModel
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetOutput
+from ..controlnets.controlnet_union import ControlNetUnionModel
+from ..modeling_utils import ModelMixin
 logger = logging.get_logger(__name__)
...
@@ -18,10 +18,9 @@ import torch
 from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..attention import BasicTransformerBlock
 from ..cache_utils import CacheMixin
-from ..embeddings import PatchEmbed
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle
...
@@ -21,16 +21,12 @@ import torch.nn as nn
 import torch.utils.checkpoint
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    StableAudioAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.transformers.transformer_2d import Transformer2DModelOutput
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
+from ..modeling_utils import ModelMixin
+from ..transformers.transformer_2d import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
...
@@ -19,18 +19,13 @@ import torch
 import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    CogVideoXAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
-from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
...
@@ -21,22 +21,22 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import (
     Attention,
     AttentionProcessor,
     FluxAttnProcessor2_0,
     FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
 )
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.import_utils import is_torch_npu_available
-from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
...
@@ -18,19 +18,19 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
-from ...models.attention import FeedForward, JointTransformerBlock
-from ...models.attention_processor import (
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward, JointTransformerBlock
+from ..attention_processor import (
     Attention,
     AttentionProcessor,
     FusedJointAttnProcessor2_0,
     JointAttnProcessor2_0,
 )
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
...
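
Since only import paths change, a quick way to sanity-check the refactor is to import the public classes the touched modules define. A short smoke test (assumes an installed `diffusers`; the class-to-hunk mapping is inferred from the symbols visible above, e.g. `FluxAttnProcessor2_0` for the Flux transformer):

```python
# If these top-level exports import without error, the rewritten relative
# imports in the corresponding modules resolve correctly.
from diffusers import (
    CogView3PlusTransformer2DModel,
    FluxControlNetModel,
    FluxTransformer2DModel,
    SD3Transformer2DModel,
    StableAudioDiTModel,
)

print("modeling imports resolve:", FluxTransformer2DModel.__module__)
```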