Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
970e3060
Commit
970e3060
authored
Oct 06, 2022
by
anton-l
Browse files
Revert "[v0.4.0] Temporarily remove Flax modules from the public API (#755)"
This reverts commit
2e209c30
.
parent
c15cda03
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
100 additions
and
6 deletions
+100
-6
docs/source/api/models.mdx
docs/source/api/models.mdx
+18
-0
docs/source/api/schedulers.mdx
docs/source/api/schedulers.mdx
+1
-1
setup.py
setup.py
+10
-2
src/diffusers/__init__.py
src/diffusers/__init__.py
+23
-0
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+3
-0
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+5
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+3
-0
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+25
-1
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+12
-1
No files found.
docs/source/api/models.mdx
View file @
970e3060
...
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
...
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## AutoencoderKL
## AutoencoderKL
[[autodoc]] AutoencoderKL
[[autodoc]] AutoencoderKL
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
docs/source/api/schedulers.mdx
View file @
970e3060
...
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
...
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
To this end, the design of schedulers is such that:
To this end, the design of schedulers is such that:
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are currently by default in PyTorch.
- Schedulers are currently by default in PyTorch
, but are designed to be framework independent (partial Jax support currently exists)
.
## API
## API
...
...
setup.py
View file @
970e3060
...
@@ -84,10 +84,13 @@ _deps = [
...
@@ -84,10 +84,13 @@ _deps = [
"datasets"
,
"datasets"
,
"filelock"
,
"filelock"
,
"flake8>=3.8.3"
,
"flake8>=3.8.3"
,
"flax>=0.4.1"
,
"hf-doc-builder>=0.3.0"
,
"hf-doc-builder>=0.3.0"
,
"huggingface-hub>=0.10.0"
,
"huggingface-hub>=0.10.0"
,
"importlib_metadata"
,
"importlib_metadata"
,
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards>=0.1.4"
,
"modelcards>=0.1.4"
,
"numpy"
,
"numpy"
,
"onnxruntime"
,
"onnxruntime"
,
...
@@ -185,9 +188,15 @@ extras["test"] = deps_list(
...
@@ -185,9 +188,15 @@ extras["test"] = deps_list(
"torchvision"
,
"torchvision"
,
"transformers"
"transformers"
)
)
extras
[
"torch"
]
=
deps_list
(
"torch"
)
if
os
.
name
==
"nt"
:
# windows
extras
[
"flax"
]
=
[]
# jax is not supported on windows
else
:
extras
[
"flax"
]
=
deps_list
(
"jax"
,
"jaxlib"
,
"flax"
)
extras
[
"dev"
]
=
(
extras
[
"dev"
]
=
(
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
+
extras
[
"torch"
]
+
extras
[
"flax"
]
)
)
install_requires
=
[
install_requires
=
[
...
@@ -198,7 +207,6 @@ install_requires = [
...
@@ -198,7 +207,6 @@ install_requires = [
deps
[
"regex"
],
deps
[
"regex"
],
deps
[
"requests"
],
deps
[
"requests"
],
deps
[
"Pillow"
],
deps
[
"Pillow"
],
deps
[
"torch"
]
]
]
setup
(
setup
(
...
...
src/diffusers/__init__.py
View file @
970e3060
from
.utils
import
(
from
.utils
import
(
is_flax_available
,
is_inflect_available
,
is_inflect_available
,
is_onnx_available
,
is_onnx_available
,
is_scipy_available
,
is_scipy_available
,
...
@@ -60,3 +61,25 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
...
@@ -60,3 +61,25 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
from
.pipelines
import
StableDiffusionOnnxPipeline
from
.pipelines
import
StableDiffusionOnnxPipeline
else
:
else
:
from
.utils.dummy_torch_and_transformers_and_onnx_objects
import
*
# noqa F403
from
.utils.dummy_torch_and_transformers_and_onnx_objects
import
*
# noqa F403
if
is_flax_available
():
from
.modeling_flax_utils
import
FlaxModelMixin
from
.models.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.models.vae_flax
import
FlaxAutoencoderKL
from
.pipeline_flax_utils
import
FlaxDiffusionPipeline
from
.schedulers
import
(
FlaxDDIMScheduler
,
FlaxDDPMScheduler
,
FlaxKarrasVeScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
,
FlaxSchedulerMixin
,
FlaxScoreSdeVeScheduler
,
)
else
:
from
.utils.dummy_flax_objects
import
*
# noqa F403
if
is_flax_available
()
and
is_transformers_available
():
from
.pipelines
import
FlaxStableDiffusionPipeline
else
:
from
.utils.dummy_flax_and_transformers_objects
import
*
# noqa F403
src/diffusers/dependency_versions_table.py
View file @
970e3060
...
@@ -8,10 +8,13 @@ deps = {
...
@@ -8,10 +8,13 @@ deps = {
"datasets"
:
"datasets"
,
"datasets"
:
"datasets"
,
"filelock"
:
"filelock"
,
"filelock"
:
"filelock"
,
"flake8"
:
"flake8>=3.8.3"
,
"flake8"
:
"flake8>=3.8.3"
,
"flax"
:
"flax>=0.4.1"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"huggingface-hub"
:
"huggingface-hub>=0.10.0"
,
"huggingface-hub"
:
"huggingface-hub>=0.10.0"
,
"importlib_metadata"
:
"importlib_metadata"
,
"importlib_metadata"
:
"importlib_metadata"
,
"isort"
:
"isort>=5.5.4"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib"
:
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards"
:
"modelcards>=0.1.4"
,
"modelcards"
:
"modelcards>=0.1.4"
,
"numpy"
:
"numpy"
,
"numpy"
:
"numpy"
,
"onnxruntime"
:
"onnxruntime"
,
"onnxruntime"
:
"onnxruntime"
,
...
...
src/diffusers/models/__init__.py
View file @
970e3060
...
@@ -12,10 +12,14 @@
...
@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
..utils
import
is_torch_available
from
..utils
import
is_flax_available
,
is_torch_available
if
is_torch_available
():
if
is_torch_available
():
from
.unet_2d
import
UNet2DModel
from
.unet_2d
import
UNet2DModel
from
.unet_2d_condition
import
UNet2DConditionModel
from
.unet_2d_condition
import
UNet2DConditionModel
from
.vae
import
AutoencoderKL
,
VQModel
from
.vae
import
AutoencoderKL
,
VQModel
if
is_flax_available
():
from
.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.vae_flax
import
FlaxAutoencoderKL
src/diffusers/pipelines/__init__.py
View file @
970e3060
...
@@ -21,3 +21,6 @@ if is_torch_available() and is_transformers_available():
...
@@ -21,3 +21,6 @@ if is_torch_available() and is_transformers_available():
if
is_transformers_available
()
and
is_onnx_available
():
if
is_transformers_available
()
and
is_onnx_available
():
from
.stable_diffusion
import
StableDiffusionOnnxPipeline
from
.stable_diffusion
import
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
970e3060
...
@@ -6,7 +6,7 @@ import numpy as np
...
@@ -6,7 +6,7 @@ import numpy as np
import
PIL
import
PIL
from
PIL
import
Image
from
PIL
import
Image
from
...utils
import
BaseOutput
,
is_onnx_available
,
is_torch_available
,
is_transformers_available
from
...utils
import
BaseOutput
,
is_flax_available
,
is_onnx_available
,
is_torch_available
,
is_transformers_available
@
dataclass
@
dataclass
...
@@ -35,3 +35,27 @@ if is_transformers_available() and is_torch_available():
...
@@ -35,3 +35,27 @@ if is_transformers_available() and is_torch_available():
if
is_transformers_available
()
and
is_onnx_available
():
if
is_transformers_available
()
and
is_onnx_available
():
from
.pipeline_stable_diffusion_onnx
import
StableDiffusionOnnxPipeline
from
.pipeline_stable_diffusion_onnx
import
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
import
flax
@
flax
.
struct
.
dataclass
class
FlaxStableDiffusionPipelineOutput
(
BaseOutput
):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images
:
Union
[
List
[
PIL
.
Image
.
Image
],
np
.
ndarray
]
nsfw_content_detected
:
List
[
bool
]
from
...schedulers.scheduling_pndm_flax
import
PNDMSchedulerState
from
.pipeline_flax_stable_diffusion
import
FlaxStableDiffusionPipeline
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
src/diffusers/schedulers/__init__.py
View file @
970e3060
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
from
..utils
import
is_scipy_available
,
is_torch_available
from
..utils
import
is_flax_available
,
is_scipy_available
,
is_torch_available
if
is_torch_available
():
if
is_torch_available
():
...
@@ -27,6 +27,17 @@ if is_torch_available():
...
@@ -27,6 +27,17 @@ if is_torch_available():
else
:
else
:
from
..utils.dummy_pt_objects
import
*
# noqa F403
from
..utils.dummy_pt_objects
import
*
# noqa F403
if
is_flax_available
():
from
.scheduling_ddim_flax
import
FlaxDDIMScheduler
from
.scheduling_ddpm_flax
import
FlaxDDPMScheduler
from
.scheduling_karras_ve_flax
import
FlaxKarrasVeScheduler
from
.scheduling_lms_discrete_flax
import
FlaxLMSDiscreteScheduler
from
.scheduling_pndm_flax
import
FlaxPNDMScheduler
from
.scheduling_sde_ve_flax
import
FlaxScoreSdeVeScheduler
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
else
:
from
..utils.dummy_flax_objects
import
*
# noqa F403
if
is_scipy_available
()
and
is_torch_available
():
if
is_scipy_available
()
and
is_torch_available
():
from
.scheduling_lms_discrete
import
LMSDiscreteScheduler
from
.scheduling_lms_discrete
import
LMSDiscreteScheduler
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment