"docs/vscode:/vscode.git/clone" did not exist on "db1cb0b1a233cc6f4029261a67e503b322f31cd0"
Unverified Commit 29b2c93c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Make repo structure consistent (#1862)



* move files a bit

* more refactors

* fix more

* more fixes

* fix more onnx

* make style

* upload

* fix

* up

* fix more

* up again

* up

* small fix

* Update src/diffusers/__init__.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* correct
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent ab0e92fd
...@@ -23,7 +23,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -23,7 +23,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -33,6 +32,7 @@ from ...schedulers import ( ...@@ -33,6 +32,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, logging from ...utils import deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -26,7 +26,6 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTF ...@@ -26,7 +26,6 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTF
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -36,6 +35,7 @@ from ...schedulers import ( ...@@ -36,6 +35,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -24,7 +24,6 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection ...@@ -24,7 +24,6 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -34,6 +33,7 @@ from ...schedulers import ( ...@@ -34,6 +33,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, logging from ...utils import deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -35,6 +34,7 @@ from ...schedulers import ( ...@@ -35,6 +34,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -25,9 +25,9 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -25,9 +25,9 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging from ...utils import deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -35,6 +34,7 @@ from ...schedulers import ( ...@@ -35,6 +34,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -19,7 +19,7 @@ import torch ...@@ -19,7 +19,7 @@ import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from ... import DiffusionPipeline from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
......
...@@ -23,9 +23,9 @@ from diffusers.utils import is_accelerate_available ...@@ -23,9 +23,9 @@ from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging from ...utils import logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -10,7 +10,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer ...@@ -10,7 +10,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -20,6 +19,7 @@ from ...schedulers import ( ...@@ -20,6 +19,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, is_accelerate_available, logging from ...utils import deprecate, is_accelerate_available, logging
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionSafePipelineOutput from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker from .safety_checker import SafeStableDiffusionSafetyChecker
......
...@@ -17,8 +17,8 @@ from typing import List, Optional, Tuple, Union ...@@ -17,8 +17,8 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...models import UNet2DModel from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import KarrasVeScheduler from ...schedulers import KarrasVeScheduler
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class KarrasVePipeline(DiffusionPipeline): class KarrasVePipeline(DiffusionPipeline):
...@@ -68,12 +68,11 @@ class KarrasVePipeline(DiffusionPipeline): ...@@ -68,12 +68,11 @@ class KarrasVePipeline(DiffusionPipeline):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Returns: Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
generated images.
""" """
img_size = self.unet.config.sample_size img_size = self.unet.config.sample_size
......
...@@ -18,11 +18,11 @@ from typing import List, Optional, Union ...@@ -18,11 +18,11 @@ from typing import List, Optional, Union
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import UnCLIPScheduler
from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
...@@ -291,7 +291,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -291,7 +291,7 @@ class UnCLIPPipeline(DiffusionPipeline):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
""" """
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
......
...@@ -19,9 +19,6 @@ import torch ...@@ -19,9 +19,6 @@ import torch
from torch.nn import functional as F from torch.nn import functional as F
import PIL import PIL
from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import UnCLIPScheduler
from transformers import ( from transformers import (
CLIPFeatureExtractor, CLIPFeatureExtractor,
CLIPTextModelWithProjection, CLIPTextModelWithProjection,
...@@ -29,6 +26,9 @@ from transformers import ( ...@@ -29,6 +26,9 @@ from transformers import (
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
) )
from ...models import UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
...@@ -303,7 +303,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -303,7 +303,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
""" """
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
batch_size = 1 batch_size = 1
......
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
import torch import torch
from torch import nn from torch import nn
from diffusers.modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
class UnCLIPTextProjModel(ModelMixin, ConfigMixin): class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin from ...models import ModelMixin
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
from ...models.embeddings import TimestepEmbedding, Timesteps from ...models.embeddings import TimestepEmbedding, Timesteps
......
...@@ -7,9 +7,9 @@ import PIL.Image ...@@ -7,9 +7,9 @@ import PIL.Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging from ...utils import logging
from ..pipeline_utils import DiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
......
...@@ -27,11 +27,10 @@ from transformers import ( ...@@ -27,11 +27,10 @@ from transformers import (
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
) )
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...models.attention import DualTransformer2DModel, Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
......
...@@ -23,9 +23,9 @@ import PIL ...@@ -23,9 +23,9 @@ import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -20,11 +20,10 @@ import torch.utils.checkpoint ...@@ -20,11 +20,10 @@ import torch.utils.checkpoint
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging from ...utils import is_accelerate_available, logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
......
...@@ -16,14 +16,13 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -16,14 +16,13 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from diffusers import Transformer2DModel, VQModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin from ...models import ModelMixin, Transformer2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import VQDiffusionScheduler
from ...utils import logging from ...utils import logging
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -212,7 +211,7 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -212,7 +211,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*): callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
...@@ -221,9 +220,8 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -221,9 +220,8 @@ class VQDiffusionPipeline(DiffusionPipeline):
called at every step. called at every step.
Returns: Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict`
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
generated images.
""" """
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
......
...@@ -18,7 +18,22 @@ import os ...@@ -18,7 +18,22 @@ import os
from packaging import version from packaging import version
from .. import __version__ from .. import __version__
from .constants import (
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CONFIG_NAME,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate from .deprecation_utils import deprecate
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
from .import_utils import ( from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES, ENV_VARS_TRUE_VALUES,
...@@ -67,36 +82,6 @@ if is_torch_available(): ...@@ -67,36 +82,6 @@ if is_torch_available():
logger = get_logger(__name__) logger = get_logger(__name__)
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"HeunDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverSinglestepScheduler",
]
def check_min_version(min_version): def check_min_version(min_version):
if version.parse(__version__) < version.parse(min_version): if version.parse(__version__) < version.parse(min_version):
if "dev" in min_version: if "dev" in min_version:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment