Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
2e209c30
Unverified
Commit
2e209c30
authored
Oct 06, 2022
by
Anton Lozhkov
Committed by
GitHub
Oct 06, 2022
Browse files
[v0.4.0] Temporarily remove Flax modules from the public API (#755)
Temporarily remove Flax modules from the public API
parent
9c9462f3
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
6 additions
and
100 deletions
+6
-100
docs/source/api/models.mdx
docs/source/api/models.mdx
+0
-18
docs/source/api/schedulers.mdx
docs/source/api/schedulers.mdx
+1
-1
setup.py
setup.py
+2
-10
src/diffusers/__init__.py
src/diffusers/__init__.py
+0
-23
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+0
-3
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-5
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+0
-3
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+1
-25
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-12
No files found.
docs/source/api/models.mdx
View file @
2e209c30
...
...
@@ -45,21 +45,3 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## AutoencoderKL
[[autodoc]] AutoencoderKL
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
docs/source/api/schedulers.mdx
View file @
2e209c30
...
...
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
To this end, the design of schedulers is such that:
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are currently by default in PyTorch
, but are designed to be framework independent (partial Jax support currently exists)
.
- Schedulers are currently by default in PyTorch.
## API
...
...
setup.py
View file @
2e209c30
...
...
@@ -84,13 +84,10 @@ _deps = [
"datasets"
,
"filelock"
,
"flake8>=3.8.3"
,
"flax>=0.4.1"
,
"hf-doc-builder>=0.3.0"
,
"huggingface-hub>=0.10.0"
,
"importlib_metadata"
,
"isort>=5.5.4"
,
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards>=0.1.4"
,
"numpy"
,
"onnxruntime"
,
...
...
@@ -188,15 +185,9 @@ extras["test"] = deps_list(
"torchvision"
,
"transformers"
)
extras
[
"torch"
]
=
deps_list
(
"torch"
)
if
os
.
name
==
"nt"
:
# windows
extras
[
"flax"
]
=
[]
# jax is not supported on windows
else
:
extras
[
"flax"
]
=
deps_list
(
"jax"
,
"jaxlib"
,
"flax"
)
extras
[
"dev"
]
=
(
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
+
extras
[
"torch"
]
+
extras
[
"flax"
]
extras
[
"quality"
]
+
extras
[
"test"
]
+
extras
[
"training"
]
+
extras
[
"docs"
]
)
install_requires
=
[
...
...
@@ -207,6 +198,7 @@ install_requires = [
deps
[
"regex"
],
deps
[
"requests"
],
deps
[
"Pillow"
],
deps
[
"torch"
]
]
setup
(
...
...
src/diffusers/__init__.py
View file @
2e209c30
from
.utils
import
(
is_flax_available
,
is_inflect_available
,
is_onnx_available
,
is_scipy_available
,
...
...
@@ -61,25 +60,3 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
from
.pipelines
import
StableDiffusionOnnxPipeline
else
:
from
.utils.dummy_torch_and_transformers_and_onnx_objects
import
*
# noqa F403
if
is_flax_available
():
from
.modeling_flax_utils
import
FlaxModelMixin
from
.models.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.models.vae_flax
import
FlaxAutoencoderKL
from
.pipeline_flax_utils
import
FlaxDiffusionPipeline
from
.schedulers
import
(
FlaxDDIMScheduler
,
FlaxDDPMScheduler
,
FlaxKarrasVeScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
,
FlaxSchedulerMixin
,
FlaxScoreSdeVeScheduler
,
)
else
:
from
.utils.dummy_flax_objects
import
*
# noqa F403
if
is_flax_available
()
and
is_transformers_available
():
from
.pipelines
import
FlaxStableDiffusionPipeline
else
:
from
.utils.dummy_flax_and_transformers_objects
import
*
# noqa F403
src/diffusers/dependency_versions_table.py
View file @
2e209c30
...
...
@@ -8,13 +8,10 @@ deps = {
"datasets"
:
"datasets"
,
"filelock"
:
"filelock"
,
"flake8"
:
"flake8>=3.8.3"
,
"flax"
:
"flax>=0.4.1"
,
"hf-doc-builder"
:
"hf-doc-builder>=0.3.0"
,
"huggingface-hub"
:
"huggingface-hub>=0.10.0"
,
"importlib_metadata"
:
"importlib_metadata"
,
"isort"
:
"isort>=5.5.4"
,
"jax"
:
"jax>=0.2.8,!=0.3.2,<=0.3.6"
,
"jaxlib"
:
"jaxlib>=0.1.65,<=0.3.6"
,
"modelcards"
:
"modelcards>=0.1.4"
,
"numpy"
:
"numpy"
,
"onnxruntime"
:
"onnxruntime"
,
...
...
src/diffusers/models/__init__.py
View file @
2e209c30
...
...
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
..utils
import
is_flax_available
,
is_torch_available
from
..utils
import
is_torch_available
if
is_torch_available
():
from
.unet_2d
import
UNet2DModel
from
.unet_2d_condition
import
UNet2DConditionModel
from
.vae
import
AutoencoderKL
,
VQModel
if
is_flax_available
():
from
.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.vae_flax
import
FlaxAutoencoderKL
src/diffusers/pipelines/__init__.py
View file @
2e209c30
...
...
@@ -21,6 +21,3 @@ if is_torch_available() and is_transformers_available():
if
is_transformers_available
()
and
is_onnx_available
():
from
.stable_diffusion
import
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
2e209c30
...
...
@@ -6,7 +6,7 @@ import numpy as np
import
PIL
from
PIL
import
Image
from
...utils
import
BaseOutput
,
is_flax_available
,
is_onnx_available
,
is_torch_available
,
is_transformers_available
from
...utils
import
BaseOutput
,
is_onnx_available
,
is_torch_available
,
is_transformers_available
@
dataclass
...
...
@@ -35,27 +35,3 @@ if is_transformers_available() and is_torch_available():
if
is_transformers_available
()
and
is_onnx_available
():
from
.pipeline_stable_diffusion_onnx
import
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
import
flax
@
flax
.
struct
.
dataclass
class
FlaxStableDiffusionPipelineOutput
(
BaseOutput
):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images
:
Union
[
List
[
PIL
.
Image
.
Image
],
np
.
ndarray
]
nsfw_content_detected
:
List
[
bool
]
from
...schedulers.scheduling_pndm_flax
import
PNDMSchedulerState
from
.pipeline_flax_stable_diffusion
import
FlaxStableDiffusionPipeline
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
src/diffusers/schedulers/__init__.py
View file @
2e209c30
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
from
..utils
import
is_flax_available
,
is_scipy_available
,
is_torch_available
from
..utils
import
is_scipy_available
,
is_torch_available
if
is_torch_available
():
...
...
@@ -27,17 +27,6 @@ if is_torch_available():
else
:
from
..utils.dummy_pt_objects
import
*
# noqa F403
if
is_flax_available
():
from
.scheduling_ddim_flax
import
FlaxDDIMScheduler
from
.scheduling_ddpm_flax
import
FlaxDDPMScheduler
from
.scheduling_karras_ve_flax
import
FlaxKarrasVeScheduler
from
.scheduling_lms_discrete_flax
import
FlaxLMSDiscreteScheduler
from
.scheduling_pndm_flax
import
FlaxPNDMScheduler
from
.scheduling_sde_ve_flax
import
FlaxScoreSdeVeScheduler
from
.scheduling_utils_flax
import
FlaxSchedulerMixin
else
:
from
..utils.dummy_flax_objects
import
*
# noqa F403
if
is_scipy_available
()
and
is_torch_available
():
from
.scheduling_lms_discrete
import
LMSDiscreteScheduler
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment