Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
721e0174
Unverified
Commit
721e0174
authored
Sep 13, 2022
by
Patrick von Platen
Committed by
GitHub
Sep 13, 2022
Browse files
[Flax] Make room for more frameworks (#494)
* start * finish
parent
f4781a0b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
227 additions
and
57 deletions
+227
-57
setup.py
setup.py
+14
-4
src/diffusers/__init__.py
src/diffusers/__init__.py
+33
-31
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+3
-0
src/diffusers/utils/dummy_pt_objects.py
src/diffusers/utils/dummy_pt_objects.py
+165
-0
src/diffusers/utils/dummy_torch_and_scipy_objects.py
src/diffusers/utils/dummy_torch_and_scipy_objects.py
+2
-2
src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
...rs/utils/dummy_torch_and_transformers_and_onnx_objects.py
+2
-2
src/diffusers/utils/dummy_torch_and_transformers_objects.py
src/diffusers/utils/dummy_torch_and_transformers_objects.py
+8
-8
src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
...s/dummy_transformers_and_inflect_and_unidecode_objects.py
+0
-10
No files found.
setup.py
View file @
721e0174
...
@@ -68,6 +68,7 @@ To create the package for pypi.
...
@@ -68,6 +68,7 @@ To create the package for pypi.
"""
"""
import
re
import
re
import
os
from
distutils.core
import
Command
from
distutils.core
import
Command
from
setuptools
import
find_packages
,
setup
from
setuptools
import
find_packages
,
setup
...
@@ -82,10 +83,13 @@ _deps = [
...
@@ -82,10 +83,13 @@ _deps = [
"datasets"
,
"datasets"
,
"filelock"
,
"filelock"
,
"flake8>=3.8.3"
,
"flake8>=3.8.3"
,
"flax>=0.4.1"
,
"hf-doc-builder>=0.3.0"
,
"hf-doc-builder>=0.3.0"
,
"huggingface-hub>=0.8.1"
,
"huggingface-hub>=0.8.1"
,
"importlib_metadata"
,
"importlib_metadata"
,
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards==0.1.4"
,
"modelcards==0.1.4"
,
"numpy"
,
"numpy"
,
"pytest"
,
"pytest"
,
...
@@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui
...
@@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui
extras
[
"docs"
]
=
[
"hf-doc-builder"
]
extras
[
"docs"
]
=
[
"hf-doc-builder"
]
extras
[
"training"
]
=
[
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
]
extras
[
"training"
]
=
[
"accelerate"
,
"datasets"
,
"tensorboard"
,
"modelcards"
]
extras
[
"test"
]
=
[
"datasets"
,
"onnxruntime"
,
"pytest"
,
"pytest-timeout"
,
"pytest-xdist"
,
"scipy"
,
"transformers"
]
extras
[
"test"
]
=
[
"datasets"
,
"onnxruntime"
,
"pytest"
,
"pytest-timeout"
,
"pytest-xdist"
,
"scipy"
,
"transformers"
]
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
extras
[
"torch"
]
=
deps_list
(
"torch"
)
if
os
.
name
==
"nt"
:
# windows
extras
[
"flax"
]
=
[]
# jax is not supported on windows
else
:
extras
[
"flax"
]
=
deps_list
(
"jax"
,
"jaxlib"
,
"flax"
)
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
+
extras
[
"torch"
]
+
extras
[
"flax"
]
install_requires
=
[
install_requires
=
[
deps
[
"importlib_metadata"
],
deps
[
"importlib_metadata"
],
...
@@ -180,13 +191,12 @@ install_requires = [
...
@@ -180,13 +191,12 @@ install_requires = [
deps
[
"numpy"
],
deps
[
"numpy"
],
deps
[
"regex"
],
deps
[
"regex"
],
deps
[
"requests"
],
deps
[
"requests"
],
deps
[
"torch"
],
deps
[
"Pillow"
],
deps
[
"Pillow"
],
]
]
setup
(
setup
(
name
=
"diffusers"
,
name
=
"diffusers"
,
version
=
"0.4.0.dev0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version
=
"0.4.0.dev0"
,
# expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description
=
"Diffusers"
,
description
=
"Diffusers"
,
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description
=
open
(
"README.md"
,
"r"
,
encoding
=
"utf-8"
).
read
(),
long_description_content_type
=
"text/markdown"
,
long_description_content_type
=
"text/markdown"
,
...
@@ -198,7 +208,7 @@ setup(
...
@@ -198,7 +208,7 @@ setup(
package_dir
=
{
""
:
"src"
},
package_dir
=
{
""
:
"src"
},
packages
=
find_packages
(
"src"
),
packages
=
find_packages
(
"src"
),
include_package_data
=
True
,
include_package_data
=
True
,
python_requires
=
">=3.
6
.0"
,
python_requires
=
">=3.
7
.0"
,
install_requires
=
install_requires
,
install_requires
=
install_requires
,
extras_require
=
extras
,
extras_require
=
extras
,
entry_points
=
{
"console_scripts"
:
[
"diffusers-cli=diffusers.commands.diffusers_cli:main"
]},
entry_points
=
{
"console_scripts"
:
[
"diffusers-cli=diffusers.commands.diffusers_cli:main"
]},
...
...
src/diffusers/__init__.py
View file @
721e0174
...
@@ -2,6 +2,7 @@ from .utils import (
...
@@ -2,6 +2,7 @@ from .utils import (
is_inflect_available
,
is_inflect_available
,
is_onnx_available
,
is_onnx_available
,
is_scipy_available
,
is_scipy_available
,
is_torch_available
,
is_transformers_available
,
is_transformers_available
,
is_unidecode_available
,
is_unidecode_available
,
)
)
...
@@ -10,40 +11,42 @@ from .utils import (
...
@@ -10,40 +11,42 @@ from .utils import (
__version__
=
"0.4.0.dev0"
__version__
=
"0.4.0.dev0"
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
AutoencoderKL
,
UNet2DConditionModel
,
UNet2DModel
,
VQModel
from
.onnx_utils
import
OnnxRuntimeModel
from
.onnx_utils
import
OnnxRuntimeModel
from
.optimization
import
(
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
get_scheduler
,
)
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIMPipeline
,
DDPMPipeline
,
KarrasVePipeline
,
LDMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
KarrasVeScheduler
,
PNDMScheduler
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
)
from
.utils
import
logging
from
.utils
import
logging
if
is_scipy_available
():
if
is_torch_available
():
from
.schedulers
import
LMSDiscreteScheduler
from
.modeling_utils
import
ModelMixin
from
.models
import
AutoencoderKL
,
UNet2DConditionModel
,
UNet2DModel
,
VQModel
from
.optimization
import
(
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
,
get_polynomial_decay_schedule_with_warmup
,
get_scheduler
,
)
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIMPipeline
,
DDPMPipeline
,
KarrasVePipeline
,
LDMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
KarrasVeScheduler
,
PNDMScheduler
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
)
from
.training_utils
import
EMAModel
else
:
else
:
from
.utils.dummy_scipy_objects
import
*
# noqa F403
from
.utils.dummy_pt_objects
import
*
# noqa F403
from
.training_utils
import
EMAModel
if
is_torch_available
()
and
is_scipy_available
():
from
.schedulers
import
LMSDiscreteScheduler
else
:
from
.utils.dummy_torch_and_scipy_objects
import
*
# noqa F403
if
is_transformers_available
():
if
is_torch_available
()
and
is_transformers_available
():
from
.pipelines
import
(
from
.pipelines
import
(
LDMTextToImagePipeline
,
LDMTextToImagePipeline
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionImg2ImgPipeline
,
...
@@ -51,10 +54,9 @@ if is_transformers_available():
...
@@ -51,10 +54,9 @@ if is_transformers_available():
StableDiffusionPipeline
,
StableDiffusionPipeline
,
)
)
else
:
else
:
from
.utils.dummy_transformers_objects
import
*
# noqa F403
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
if
is_transformers_available
()
and
is_onnx_available
():
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
from
.pipelines
import
StableDiffusionOnnxPipeline
from
.pipelines
import
StableDiffusionOnnxPipeline
else
:
else
:
from
.utils.dummy_transformers_and_onnx_objects
import
*
# noqa F403
from
.utils.dummy_
torch_and_
transformers_and_onnx_objects
import
*
# noqa F403
src/diffusers/dependency_versions_table.py
View file @
721e0174
...
@@ -8,10 +8,13 @@ deps = {
...
@@ -8,10 +8,13 @@ deps = {
"datasets"
:
"datasets"
,
"datasets"
:
"datasets"
,
"filelock"
:
"filelock"
,
"filelock"
:
"filelock"
,
"flake8"
:
"flake8>=3.8.3"
,
"flake8"
:
"flake8>=3.8.3"
,
"flax"
:
"flax>=0.4.1"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"huggingface-hub"
:
"huggingface-hub>=0.8.1"
,
"huggingface-hub"
:
"huggingface-hub>=0.8.1"
,
"importlib_metadata"
:
"importlib_metadata"
,
"importlib_metadata"
:
"importlib_metadata"
,
"isort"
:
"isort>=5.5.4"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib"
:
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards"
:
"modelcards==0.1.4"
,
"modelcards"
:
"modelcards==0.1.4"
,
"numpy"
:
"numpy"
,
"numpy"
:
"numpy"
,
"pytest"
:
"pytest"
,
"pytest"
:
"pytest"
,
...
...
src/diffusers/utils/dummy_pt_objects.py
0 → 100644
View file @
721e0174
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
ModelMixin
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
AutoencoderKL
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
UNet2DConditionModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
UNet2DModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
VQModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
def
get_constant_schedule
(
*
args
,
**
kwargs
):
requires_backends
(
get_constant_schedule
,
[
"torch"
])
def
get_constant_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_constant_schedule_with_warmup
,
[
"torch"
])
def
get_cosine_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_cosine_schedule_with_warmup
,
[
"torch"
])
def
get_cosine_with_hard_restarts_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_cosine_with_hard_restarts_schedule_with_warmup
,
[
"torch"
])
def
get_linear_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_linear_schedule_with_warmup
,
[
"torch"
])
def
get_polynomial_decay_schedule_with_warmup
(
*
args
,
**
kwargs
):
requires_backends
(
get_polynomial_decay_schedule_with_warmup
,
[
"torch"
])
def
get_scheduler
(
*
args
,
**
kwargs
):
requires_backends
(
get_scheduler
,
[
"torch"
])
class
DiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDIMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDPMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
KarrasVePipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
LDMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
PNDMPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
ScoreSdeVePipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDIMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
DDPMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
KarrasVeScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
PNDMScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
SchedulerMixin
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
ScoreSdeVeScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
EMAModel
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
src/diffusers/utils/dummy_scipy_objects.py
→
src/diffusers/utils/dummy_
torch_and_
scipy_objects.py
View file @
721e0174
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class
LMSDiscreteScheduler
(
metaclass
=
DummyObject
):
class
LMSDiscreteScheduler
(
metaclass
=
DummyObject
):
_backends
=
[
"scipy"
]
_backends
=
[
"torch"
,
"scipy"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"scipy"
])
requires_backends
(
self
,
[
"torch"
,
"scipy"
])
src/diffusers/utils/dummy_transformers_and_onnx_objects.py
→
src/diffusers/utils/dummy_
torch_and_
transformers_and_onnx_objects.py
View file @
721e0174
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
...
@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
,
"onnx"
]
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
,
"onnx"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
,
"onnx"
])
src/diffusers/utils/dummy_transformers_objects.py
→
src/diffusers/utils/dummy_
torch_and_
transformers_objects.py
View file @
721e0174
...
@@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends
...
@@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends
class
LDMTextToImagePipeline
(
metaclass
=
DummyObject
):
class
LDMTextToImagePipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
class
StableDiffusionImg2ImgPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionImg2ImgPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
class
StableDiffusionInpaintPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionInpaintPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
class
StableDiffusionPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
_backends
=
[
"torch"
,
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
requires_backends
(
self
,
[
"torch"
,
"transformers"
])
src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
deleted
100644 → 0
View file @
f4781a0b
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
GradTTSPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
,
"inflect"
,
"unidecode"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
,
"inflect"
,
"unidecode"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment