# Availability probes for optional dependencies. The conditional import
# sections later in this module call several of these to decide which
# objects to expose.
from .utils import (
    is_flax_available,
    is_inflect_available,
    is_onnx_available,
    is_scipy_available,
    is_torch_available,
    is_transformers_available,
    is_unidecode_available,
)


# Package version string. Under PEP 440 a ".devN" suffix denotes a
# developmental (pre-release) version.
__version__ = "0.7.0.dev0"

# Imported unconditionally: unlike the sections below, these are not
# guarded by any availability check in this module.
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import logging


# PyTorch-backed objects: models, learning-rate schedule helpers, pipelines,
# schedulers and training utilities are only imported when
# is_torch_available() returns a truthy value.
if is_torch_available():
    from .modeling_utils import ModelMixin
    from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
    from .optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
        get_linear_schedule_with_warmup,
        get_polynomial_decay_schedule_with_warmup,
        get_scheduler,
    )
    from .pipeline_utils import DiffusionPipeline
    from .pipelines import (
        DanceDiffusionPipeline,
        DDIMPipeline,
        DDPMPipeline,
        KarrasVePipeline,
        LDMPipeline,
        PNDMPipeline,
        ScoreSdeVePipeline,
    )
    from .schedulers import (
        DDIMScheduler,
        DDPMScheduler,
        IPNDMScheduler,
        KarrasVeScheduler,
        PNDMScheduler,
        SchedulerMixin,
        ScoreSdeVeScheduler,
    )
    from .training_utils import EMAModel
else:
    # Fallback when torch is unavailable. NOTE(review): presumably the dummy
    # module defines placeholders under the same names as the imports above;
    # confirm in utils/dummy_pt_objects.py.
    from .utils.dummy_pt_objects import *  # noqa F403

# LMSDiscreteScheduler is only imported when both the torch and scipy
# availability checks pass; if either fails, the dummy module is
# star-imported instead.
if is_torch_available() and is_scipy_available():
    from .schedulers import LMSDiscreteScheduler
else:
    # NOTE(review): presumably provides a placeholder LMSDiscreteScheduler;
    # confirm in utils/dummy_torch_and_scipy_objects.py.
    from .utils.dummy_torch_and_scipy_objects import *  # noqa F403

# Pipelines imported only when both the torch and transformers availability
# checks pass; if either fails, the dummy module is star-imported instead.
if is_torch_available() and is_transformers_available():
    from .pipelines import (
        LDMTextToImagePipeline,
        StableDiffusionImg2ImgPipeline,
        StableDiffusionInpaintPipeline,
        StableDiffusionInpaintPipelineLegacy,
        StableDiffusionPipeline,
    )
else:
    # NOTE(review): presumably provides placeholders for the pipeline names
    # above; confirm in utils/dummy_torch_and_transformers_objects.py.
    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403

# ONNX pipelines are imported only when the torch, transformers and onnx
# availability checks all pass; if any one fails, the dummy module is
# star-imported instead.
if is_torch_available() and is_transformers_available() and is_onnx_available():
    from .pipelines import (
        OnnxStableDiffusionImg2ImgPipeline,
        OnnxStableDiffusionInpaintPipeline,
        OnnxStableDiffusionPipeline,
        StableDiffusionOnnxPipeline,
    )
else:
    # NOTE(review): presumably provides placeholders for the pipeline names
    # above; confirm in utils/dummy_torch_and_transformers_and_onnx_objects.py.
    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403

# Flax-backed objects: the model mixin, two model classes, the pipeline base
# and the Flax schedulers are only imported when is_flax_available() returns
# a truthy value.
if is_flax_available():
    from .modeling_flax_utils import FlaxModelMixin
    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .models.vae_flax import FlaxAutoencoderKL
    from .pipeline_flax_utils import FlaxDiffusionPipeline
    from .schedulers import (
        FlaxDDIMScheduler,
        FlaxDDPMScheduler,
        FlaxKarrasVeScheduler,
        FlaxLMSDiscreteScheduler,
        FlaxPNDMScheduler,
        FlaxSchedulerMixin,
        FlaxScoreSdeVeScheduler,
    )
else:
    # Fallback when Flax is unavailable. NOTE(review): presumably the dummy
    # module defines placeholders under the same names as the imports above;
    # confirm in utils/dummy_flax_objects.py.
    from .utils.dummy_flax_objects import *  # noqa F403

# FlaxStableDiffusionPipeline is only imported when both the flax and
# transformers availability checks pass; if either fails, the dummy module
# is star-imported instead.
if is_flax_available() and is_transformers_available():
    from .pipelines import FlaxStableDiffusionPipeline
else:
    # NOTE(review): presumably provides a placeholder
    # FlaxStableDiffusionPipeline; confirm in
    # utils/dummy_flax_and_transformers_objects.py.
    from .utils.dummy_flax_and_transformers_objects import *  # noqa F403