Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
721e0174
Unverified
Commit
721e0174
authored
Sep 13, 2022
by
Patrick von Platen
Committed by
GitHub
Sep 13, 2022
Browse files
[Flax] Make room for more frameworks (#494)
* start * finish
parent
f4781a0b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
227 additions
and
57 deletions
+227
-57
setup.py
setup.py
+14
-4
src/diffusers/__init__.py
src/diffusers/__init__.py
+33
-31
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+3
-0
src/diffusers/utils/dummy_pt_objects.py
src/diffusers/utils/dummy_pt_objects.py
+165
-0
src/diffusers/utils/dummy_torch_and_scipy_objects.py
src/diffusers/utils/dummy_torch_and_scipy_objects.py
+2
-2
src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
...rs/utils/dummy_torch_and_transformers_and_onnx_objects.py
+2
-2
src/diffusers/utils/dummy_torch_and_transformers_objects.py
src/diffusers/utils/dummy_torch_and_transformers_objects.py
+8
-8
src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
...s/dummy_transformers_and_inflect_and_unidecode_objects.py
+0
-10
No files found.
setup.py
View file @
721e0174
...
@@ -68,6 +68,7 @@ To create the package for pypi.
...
@@ -68,6 +68,7 @@ To create the package for pypi.
"""
"""
import
re
import
re
import
os
from
distutils.core
import
Command
from
distutils.core
import
Command
from
setuptools
import
find_packages
,
setup
from
setuptools
import
find_packages
,
setup
...
@@ -82,10 +83,13 @@ _deps = [
...
@@ -82,10 +83,13 @@ _deps = [
"datasets"
,
"datasets"
,
"filelock"
,
"filelock"
,
"flake8>=3.8.3"
,
"flake8>=3.8.3"
,
"flax>=0.4.1"
,
"hf-doc-builder>=0.3.0"
,
"hf-doc-builder>=0.3.0"
,
"huggingface-hub>=0.8.1"
,
"huggingface-hub>=0.8.1"
,
"importlib_metadata"
,
"importlib_metadata"
,
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards==0.1.4"
,
"modelcards==0.1.4"
,
"numpy"
,
"numpy"
,
"pytest"
,
"pytest"
,
...
@@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui
...
@@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui
extras
[
"docs"
]
=
[
"hf-doc-builder"
]
extras
[
"docs"
]
=
[
"hf-doc-builder"
]
extras
[
"training"
]
=
[
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
]
extras
[
"training"
]
=
[
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
]
extras
[
"test"
]
=
[
"datasets"
,
"onnxruntime"
,
"pytest"
,
"pytest-timeout"
,
"pytest-xdist"
,
"scipy"
,
"transformers"
]
extras
[
"test"
]
=
[
"datasets"
,
"onnxruntime"
,
"pytest"
,
"pytest-timeout"
,
"pytest-xdist"
,
"scipy"
,
"transformers"
]
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
extras
[
"torch"
]
=
deps_list
(
"torch"
)
if
os
.
name
==
"nt"
:
# windows
extras
[
"flax"
]
=
[]
# jax is not supported on windows
else
:
extras
[
"flax"
]
=
deps_list
(
"jax"
,
"jaxlib"
,
"flax"
)
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
+
extras
[
"torch"
]
+
extras
[
"flax"
]
install_requires
=
[
install_requires
=
[
deps
[
"importlib_metadata"
],
deps
[
"importlib_metadata"
],
...
@@ -180,13 +191,12 @@ install_requires = [
...
@@ -180,13 +191,12 @@ install_requires = [
deps
[
"numpy"
],
deps
[
"numpy"
],
deps
[
"regex"
],
deps
[
"regex"
],
deps
[
"requests"
],
deps
[
"requests"
],
deps
[
"torch"
],
deps
[
"Pillow"
],
deps
[
"Pillow"
],
]
]
setup
(
setup
(
name
=
"diffusers"
,
name
=
"diffusers"
,
version
=
"0.4.0.dev0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version
=
"0.4.0.dev0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description
=
"Diffusers"
,
description
=
"Diffusers"
,
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description_content_type
=
"text/markdown"
,
long_description_content_type
=
"text/markdown"
,
...
@@ -198,7 +208,7 @@ setup(
...
@@ -198,7 +208,7 @@ setup(
package_dir
=
{
""
:
"src"
},
package_dir
=
{
""
:
"src"
},
packages
=
find_packages
(
"src"
),
packages
=
find_packages
(
"src"
),
include_package_data
=
True
,
include_package_data
=
True
,
python_requires
=
">=3.
6
.0"
,
python_requires
=
">=3.
7
.0"
,
install_requires
=
install_requires
,
install_requires
=
install_requires
,
extras_require
=
extras
,
extras_require
=
extras
,
entry_points
=
{
"console_scripts"
:
[
"diffusers-cli=diffusers.commands.diffusers_cli:main"
]},
entry_points
=
{
"console_scripts"
:
[
"diffusers-cli=diffusers.commands.diffusers_cli:main"
]},
...
...
src/diffusers/__init__.py
View file @
721e0174
...
@@ -2,6 +2,7 @@ from .utils import (
...
@@ -2,6 +2,7 @@ from .utils import (
is_inflect_available
,
is_inflect_available
,
is_onnx_available
,
is_onnx_available
,
is_scipy_available
,
is_scipy_available
,
is_torch_available
,
is_transformers_available
,
is_transformers_available
,
is_unidecode_available
,
is_unidecode_available
,
)
)
...
@@ -10,40 +11,42 @@ from .utils import (
...
@@ -10,40 +11,42 @@ from .utils import (
__version__
=
"0.4.0.dev0"
__version__
=
"0.4.0.dev0"
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
AutoencoderKL
,
UNet2DConditionModel
,
UNet2DModel
,
VQModel
from
.onnx_utils
import
OnnxRuntimeModel
from
.onnx_utils
import
OnnxRuntimeModel
from
.optimization
import
(
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
get_scheduler
,
)
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIMPipeline
,
DDPMPipeline
,
KarrasVePipeline
,
LDMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
KarrasVeScheduler
,
PNDMScheduler
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
)
from
.utils
import
logging
from
.utils
import
logging
if
is_scipy_available
():
if
is_torch_available
():
from
.schedulers
import
LMSDiscreteScheduler
from
.modeling_utils
import
ModelMixin
from
.models
import
AutoencoderKL
,
UNet2DConditionModel
,
UNet2DModel
,
VQModel
from
.optimization
import
(
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
get_scheduler
,
)
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIMPipeline
,
DDPMPipeline
,
KarrasVePipeline
,
LDMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
KarrasVeScheduler
,
PNDMScheduler
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
)
from
.training_utils
import
EMAModel
else
:
else
:
from
.utils.dummy_scipy_objects
import
*
# noqa F403
from
.utils.dummy_pt_objects
import
*
# noqa F403
from
.training_utils
import
EMAModel
if
is_torch_available
()
and
is_scipy_available
():
from
.schedulers
import
LMSDiscreteScheduler
else
:
from
.utils.dummy_torch_and_scipy_objects
import
*
# noqa F403
if
is_transformers_available
():
if
is_torch_available
()
and
is_transformers_available
():
from
.pipelines
import
(
from
.pipelines
import
(
LDMTextToImagePipeline
,
LDMTextToImagePipeline
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionImg2ImgPipeline
,
...
@@ -51,10 +54,9 @@ if is_transformers_available():
...
@@ -51,10 +54,9 @@ if is_transformers_available():
StableDiffusionPipeline
,
StableDiffusionPipeline
,
)
)
else
:
else
:
from
.utils.dummy_transformers_objects
import
*
# noqa F403
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
if
is_transformers_available
()
and
is_onnx_available
():
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
from
.pipelines
import
StableDiffusionOnnxPipeline
from
.pipelines
import
StableDiffusionOnnxPipeline
else
:
else
:
from
.utils.dummy_transformers_and_onnx_objects
import
*
# noqa F403
from
.utils.dummy_
torch_and_
transformers_and_onnx_objects
import
*
# noqa F403
src/diffusers/dependency_versions_table.py
View file @
721e0174
...
@@ -8,10 +8,13 @@ deps = {
...
@@ -8,10 +8,13 @@ deps = {
"datasets"
:
"datasets"
,
"datasets"
:
"datasets"
,
"filelock"
:
"filelock"
,
"filelock"
:
"filelock"
,
"flake8"
:
"flake8>=3.8.3"
,
"flake8"
:
"flake8>=3.8.3"
,
"flax"
:
"flax>=0.4.1"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"huggingface-hub"
:
"huggingface-hub>=0.8.1"
,
"huggingface-hub"
:
"huggingface-hub>=0.8.1"
,
"importlib_metadata"
:
"importlib_metadata"
,
"importlib_metadata"
:
"importlib_metadata"
,
"isort"
:
"isort>=5.5.4"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib"
:
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards"
:
"modelcards==0.1.4"
,
"modelcards"
:
"modelcards==0.1.4"
,
"numpy"
:
"numpy"
,
"numpy"
:
"numpy"
,
"pytest"
:
"pytest"
,
"pytest"
:
"pytest"
,
...
...
src/diffusers/utils/dummy_pt_objects.py
0 → 100644
View file @
721e0174
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
ModelMixin
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
AutoencoderKL
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
UNet2DConditionModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
UNet2DModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
VQModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
def
get_constant_schedule
(
*
args
,
**
kwargs
):
requires_backends
(
get_constant_schedule
,
[
"torch"
])
def
get_constant_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_constant_schedule_with_warmup
,
[
"torch"
])
def
get_cosine_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_cosine_schedule_with_warmup
,
[
"torch"
])
def
get_cosine_with_hard_restarts_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_cosine_with_hard_restarts_schedule_with_warmup
,
[
"torch"
])
def
get_linear_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_linear_schedule_with_warmup
,
[
"torch"
])
def
get_polynomial_decay_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_polynomial_decay_schedule_with_warmup
,
[
"torch"
])
def
get_scheduler
(
*
args
,
**
kwargs
):
requires_backends
(
get_scheduler
,
[
"torch"
])
class
DiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDIMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDPMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
KarrasVePipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
LDMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
PNDMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
ScoreSdeVePipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDIMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDPMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
KarrasVeScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
PNDMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
SchedulerMixin
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
ScoreSdeVeScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
EMAModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
src/diffusers/utils/dummy_scipy_objects.py
→
src/diffusers/utils/dummy_
torch_and_
scipy_objects.py
View file @
721e0174
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class
LMSDiscreteScheduler
(
metaclass
=
DummyObject
):
class
LMSDiscreteScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"scipy"
]
_backends
=
[
"torch"
,
"scipy"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"scipy"
])
requires_backends
(
self
,
[
"torch"
,
"scipy"
])
src/diffusers/utils/dummy_transformers_and_onnx_objects.py
→
src/diffusers/utils/dummy_
torch_and_
transformers_and_onnx_objects.py
View file @
721e0174
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
,
"onnx"
]
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
,
"onnx"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
,
"onnx"
])
src/diffusers/utils/dummy_transformers_objects.py
→
src/diffusers/utils/dummy_
torch_and_
transformers_objects.py
View file @
721e0174
...
@@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends
...
@@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends
class
LDMTextToImagePipeline
(
metaclass
=
DummyObject
):
class
LDMTextToImagePipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
class
StableDiffusionImg2ImgPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionImg2ImgPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
class
StableDiffusionInpaintPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionInpaintPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
class
StableDiffusionPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
deleted
100644 → 0
View file @
f4781a0b
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
GradTTSPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
,
"inflect"
,
"unidecode"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
,
"inflect"
,
"unidecode"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment