"vscode:/vscode.git/clone" did not exist on "c595039ca67f4b03826bdab29604afcb63e907d8"
Unverified Commit a0c54828 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Deprecate Pipelines (#6169)



* deprecate pipe

* make style

* update

* add deprecation message

* format

* remove tests for deprecated pipelines

* remove deprecation message

* make style

* fix copies

* clean up

* clean

* clean

* clean

* clean up

* clean up

* clean up toctree

* clean up

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 8d891e6e
...@@ -21,17 +21,17 @@ import torch ...@@ -21,17 +21,17 @@ import torch
from packaging import version from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ....configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor from ....image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ....loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ....models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler from ....schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -6,12 +6,12 @@ import PIL.Image ...@@ -6,12 +6,12 @@ import PIL.Image
import torch import torch
from transformers import CLIPImageProcessor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ....configuration_utils import FrozenDict
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ....schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging from ....utils import deprecate, logging
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -21,17 +21,17 @@ import torch ...@@ -21,17 +21,17 @@ import torch
from packaging import version from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ....configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor from ....image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ....loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ....models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ....schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from ...stable_diffusion import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -18,17 +18,17 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -18,17 +18,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ....image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ....loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ....models.lora import adjust_lora_scale_text_encoder
from ...schedulers import PNDMScheduler from ....schedulers import PNDMScheduler
from ...schedulers.scheduling_utils import SchedulerMixin from ....schedulers.scheduling_utils import SchedulerMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -18,12 +18,12 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -18,12 +18,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ....image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ....loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ....models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ....schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ....utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
deprecate, deprecate,
logging, logging,
...@@ -31,10 +31,10 @@ from ...utils import ( ...@@ -31,10 +31,10 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -28,14 +28,14 @@ from transformers import ( ...@@ -28,14 +28,14 @@ from transformers import (
CLIPTokenizer, CLIPTokenizer,
) )
from ...image_processor import PipelineImageInput, VaeImageProcessor from ....image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ....loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention from ....models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder from ....models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from ...utils import ( from ....utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
BaseOutput, BaseOutput,
...@@ -45,10 +45,10 @@ from ...utils import ( ...@@ -45,10 +45,10 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]} _import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]}
......
...@@ -16,10 +16,10 @@ from typing import List, Optional, Tuple, Union ...@@ -16,10 +16,10 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...models import UNet2DModel from ....models import UNet2DModel
from ...schedulers import KarrasVeScheduler from ....schedulers import KarrasVeScheduler
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class KarrasVePipeline(DiffusionPipeline): class KarrasVePipeline(DiffusionPipeline):
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ....utils import (
DIFFUSERS_SLOW_IMPORT, DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
...@@ -17,7 +17,7 @@ try: ...@@ -17,7 +17,7 @@ try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ( from ....utils.dummy_torch_and_transformers_objects import (
VersatileDiffusionDualGuidedPipeline, VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline, VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline, VersatileDiffusionPipeline,
...@@ -45,7 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -45,7 +45,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ( from ....utils.dummy_torch_and_transformers_objects import (
VersatileDiffusionDualGuidedPipeline, VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline, VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline, VersatileDiffusionPipeline,
......
...@@ -7,10 +7,10 @@ import torch.nn.functional as F ...@@ -7,10 +7,10 @@ import torch.nn.functional as F
from diffusers.utils import deprecate from diffusers.utils import deprecate
from ...configuration_utils import ConfigMixin, register_to_config from ....configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin from ....models import ModelMixin
from ...models.activations import get_activation from ....models.activations import get_activation
from ...models.attention_processor import ( from ....models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS,
Attention, Attention,
...@@ -19,8 +19,8 @@ from ...models.attention_processor import ( ...@@ -19,8 +19,8 @@ from ...models.attention_processor import (
AttnAddedKVProcessor2_0, AttnAddedKVProcessor2_0,
AttnProcessor, AttnProcessor,
) )
from ...models.dual_transformer_2d import DualTransformer2DModel from ....models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import ( from ....models.embeddings import (
GaussianFourierProjection, GaussianFourierProjection,
ImageHintTimeEmbedding, ImageHintTimeEmbedding,
ImageProjection, ImageProjection,
...@@ -31,10 +31,10 @@ from ...models.embeddings import ( ...@@ -31,10 +31,10 @@ from ...models.embeddings import (
TimestepEmbedding, TimestepEmbedding,
Timesteps, Timesteps,
) )
from ...models.transformer_2d import Transformer2DModel from ....models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput from ....models.unet_2d_condition import UNet2DConditionOutput
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import apply_freeu from ....utils.torch_utils import apply_freeu
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -5,10 +5,10 @@ import PIL.Image ...@@ -5,10 +5,10 @@ import PIL.Image
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ....schedulers import KarrasDiffusionSchedulers
from ...utils import logging from ....utils import logging
from ..pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
......
...@@ -26,12 +26,12 @@ from transformers import ( ...@@ -26,12 +26,12 @@ from transformers import (
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
) )
from ...image_processor import VaeImageProcessor from ....image_processor import VaeImageProcessor
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ....models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ....schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging from ....utils import deprecate, logging
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
......
...@@ -21,12 +21,12 @@ import torch ...@@ -21,12 +21,12 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor from ....image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel from ....models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ....schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging from ....utils import deprecate, logging
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -19,12 +19,12 @@ import torch ...@@ -19,12 +19,12 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor from ....image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ....models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ....schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging from ....utils import deprecate, logging
from ...utils.torch_utils import randn_tensor from ....utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ....utils import (
DIFFUSERS_SLOW_IMPORT, DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
...@@ -16,7 +16,7 @@ try: ...@@ -16,7 +16,7 @@ try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ( from ....utils.dummy_torch_and_transformers_objects import (
LearnedClassifierFreeSamplingEmbeddings, LearnedClassifierFreeSamplingEmbeddings,
VQDiffusionPipeline, VQDiffusionPipeline,
) )
...@@ -36,7 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -36,7 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ( from ....utils.dummy_torch_and_transformers_objects import (
LearnedClassifierFreeSamplingEmbeddings, LearnedClassifierFreeSamplingEmbeddings,
VQDiffusionPipeline, VQDiffusionPipeline,
) )
......
...@@ -17,11 +17,11 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -17,11 +17,11 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...configuration_utils import ConfigMixin, register_to_config from ....configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin, Transformer2DModel, VQModel from ....models import ModelMixin, Transformer2DModel, VQModel
from ...schedulers import VQDiffusionScheduler from ....schedulers import VQDiffusionScheduler
from ...utils import logging from ....utils import logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -134,7 +134,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -134,7 +134,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .clip_image_project_model import CLIPImageProjection from .clip_image_project_model import CLIPImageProjection
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import ( from .pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionPipelineOutput, StableDiffusionPipelineOutput,
...@@ -149,9 +148,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -149,9 +148,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
) )
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import (
StableDiffusionInpaintPipelineLegacy,
)
from .pipeline_stable_diffusion_instruct_pix2pix import ( from .pipeline_stable_diffusion_instruct_pix2pix import (
StableDiffusionInstructPix2PixPipeline, StableDiffusionInstructPix2PixPipeline,
) )
...@@ -159,13 +155,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -159,13 +155,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionLatentUpscalePipeline, StableDiffusionLatentUpscalePipeline,
) )
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .pipeline_stable_diffusion_model_editing import (
StableDiffusionModelEditingPipeline,
)
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .pipeline_stable_diffusion_paradigms import (
StableDiffusionParadigmsPipeline,
)
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip import StableUnCLIPPipeline
...@@ -199,9 +189,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -199,9 +189,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionDepth2ImgPipeline, StableDiffusionDepth2ImgPipeline,
) )
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
from .pipeline_stable_diffusion_pix2pix_zero import (
StableDiffusionPix2PixZeroPipeline,
)
try: try:
if not ( if not (
...@@ -234,9 +221,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -234,9 +221,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_onnx_stable_diffusion_inpaint import ( from .pipeline_onnx_stable_diffusion_inpaint import (
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
) )
from .pipeline_onnx_stable_diffusion_inpaint_legacy import (
OnnxStableDiffusionInpaintPipelineLegacy,
)
from .pipeline_onnx_stable_diffusion_upscale import ( from .pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline, OnnxStableDiffusionUpscalePipeline,
) )
......
...@@ -788,7 +788,6 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -788,7 +788,6 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
latents = latents * self.scheduler.init_noise_sigma latents = latents * self.scheduler.init_noise_sigma
return latents return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents
def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError( raise ValueError(
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, XLMRobertaTokenizer
from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
RobertaSeriesConfig,
RobertaSeriesModelWithTransformation,
)
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
# Presumably forces deterministic torch/cuDNN behavior so the hard-coded
# expected-slice assertions below are reproducible across runs — TODO confirm
# against diffusers.utils.testing_utils.enable_full_determinism.
enable_full_determinism()
class AltDiffusionPipelineFastTests(
    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
    """Fast CPU tests for ``AltDiffusionPipeline`` using tiny randomly-initialized models."""

    pipeline_class = AltDiffusionPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS

    @staticmethod
    def _get_roberta_text_encoder():
        """Build the small seeded RobertaSeries text encoder used by the slice tests.

        Shared by ``test_alt_diffusion_ddim`` and ``test_alt_diffusion_pndm`` so the
        config is defined once. Seeded for determinism.
        """
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        return RobertaSeriesModelWithTransformation(text_encoder_config)

    def get_dummy_components(self):
        """Return a dict of tiny pipeline components suitable for fast CPU tests.

        Note: uses a CLIP text encoder rather than the RobertaSeries one because the
        latter is non-deterministic (see comment below); the e2e slice tests swap in
        the RobertaSeries encoder themselves via ``_get_roberta_text_encoder``.
        """
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        # TODO: address the non-deterministic text encoder (fails for save-load tests)
        # torch.manual_seed(0)
        # text_encoder_config = RobertaSeriesConfig(
        #     hidden_size=32,
        #     project_dim=32,
        #     intermediate_size=37,
        #     layer_norm_eps=1e-05,
        #     num_attention_heads=4,
        #     num_hidden_layers=5,
        #     vocab_size=5002,
        # )
        # text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=5002,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
        # Cap tokenizer length to the standard CLIP context window.
        tokenizer.model_max_length = 77
        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "image_encoder": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return deterministic call kwargs for the pipeline on *device*."""
        if str(device).startswith("mps"):
            # MPS does not support device-local generators; fall back to the global one.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_attention_slicing_forward_pass(self):
        # Loosened tolerance: attention slicing changes accumulation order slightly.
        super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

    def test_alt_diffusion_ddim(self):
        """End-to-end DDIM run: check a 3x3 corner slice of the output image."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["text_encoder"] = self._get_roberta_text_encoder()
        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = "A photo of an astronaut"
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_pndm(self):
        """End-to-end PNDM run: check a 3x3 corner slice of the output image."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        components["text_encoder"] = self._get_roberta_text_encoder()
        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@nightly
@require_torch_gpu
class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
    """Nightly GPU integration tests running the full BAAI/AltDiffusion checkpoint."""

    def tearDown(self):
        super().tearDown()
        # Release cached allocations so subsequent tests start with free VRAM.
        gc.collect()
        torch.cuda.empty_cache()

    def test_alt_diffusion(self):
        # make sure here that pndm scheduler skips prk
        pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None).to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(
            ["A painting of a squirrel eating a burger"],
            generator=torch.manual_seed(0),
            guidance_scale=6.0,
            num_inference_steps=20,
            output_type="np",
        )
        images = output.images
        corner = images[0, -3:, -3:, -1]

        assert images.shape == (1, 512, 512, 3)
        expected = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])

        assert np.abs(corner.flatten() - expected).max() < 1e-2

    def test_alt_diffusion_fast_ddim(self):
        ddim = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")
        pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=ddim, safety_checker=None)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(
            ["A painting of a squirrel eating a burger"],
            generator=torch.manual_seed(0),
            num_inference_steps=2,
            output_type="numpy",
        )
        images = output.images
        corner = images[0, -3:, -3:, -1]

        assert images.shape == (1, 512, 512, 3)
        expected = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])

        assert np.abs(corner.flatten() - expected).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment