Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
728a3f3e
Unverified
Commit
728a3f3e
authored
Oct 18, 2022
by
Anton Lozhkov
Committed by
GitHub
Oct 18, 2022
Browse files
Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)
Rename and deprecate
parent
100e094c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
47 additions
and
8 deletions
+47
-8
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-1
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
...elines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+26
-2
src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
...rs/utils/dummy_torch_and_transformers_and_onnx_objects.py
+15
-0
tests/test_pipelines.py
tests/test_pipelines.py
+3
-3
No files found.
src/diffusers/__init__.py
View file @
728a3f3e
...
@@ -58,7 +58,7 @@ else:
...
@@ -58,7 +58,7 @@ else:
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
from
.pipelines
import
StableDiffusionOnnxPipeline
from
.pipelines
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
else
:
else
:
from
.utils.dummy_torch_and_transformers_and_onnx_objects
import
*
# noqa F403
from
.utils.dummy_torch_and_transformers_and_onnx_objects
import
*
# noqa F403
...
...
src/diffusers/pipelines/__init__.py
View file @
728a3f3e
...
@@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available():
...
@@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available():
)
)
if
is_transformers_available
()
and
is_onnx_available
():
if
is_transformers_available
()
and
is_onnx_available
():
from
.stable_diffusion
import
StableDiffusionOnnxPipeline
from
.stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
if
is_transformers_available
()
and
is_flax_available
():
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
728a3f3e
...
@@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available():
...
@@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available():
from
.safety_checker
import
StableDiffusionSafetyChecker
from
.safety_checker
import
StableDiffusionSafetyChecker
if
is_transformers_available
()
and
is_onnx_available
():
if
is_transformers_available
()
and
is_onnx_available
():
from
.pipeline_stable_diffusion
_onnx
import
StableDiffusionOnnxPipeline
from
.pipeline_
onnx_
stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
if
is_transformers_available
()
and
is_flax_available
():
import
flax
import
flax
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion
_onnx
.py
→
src/diffusers/pipelines/stable_diffusion/pipeline_
onnx_
stable_diffusion.py
View file @
728a3f3e
...
@@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
...
@@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
from
...onnx_utils
import
OnnxRuntimeModel
from
...onnx_utils
import
OnnxRuntimeModel
from
...pipeline_utils
import
DiffusionPipeline
from
...pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
DDIMScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
from
...schedulers
import
DDIMScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
from
...utils
import
logging
from
...utils
import
deprecate
,
logging
from
.
import
StableDiffusionPipelineOutput
from
.
import
StableDiffusionPipelineOutput
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
class
StableDiffusion
Onnx
Pipeline
(
DiffusionPipeline
):
class
Onnx
StableDiffusionPipeline
(
DiffusionPipeline
):
vae_decoder
:
OnnxRuntimeModel
vae_decoder
:
OnnxRuntimeModel
text_encoder
:
OnnxRuntimeModel
text_encoder
:
OnnxRuntimeModel
tokenizer
:
CLIPTokenizer
tokenizer
:
CLIPTokenizer
...
@@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
...
@@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
return
(
image
,
has_nsfw_concept
)
return
(
image
,
has_nsfw_concept
)
return
StableDiffusionPipelineOutput
(
images
=
image
,
nsfw_content_detected
=
has_nsfw_concept
)
return
StableDiffusionPipelineOutput
(
images
=
image
,
nsfw_content_detected
=
has_nsfw_concept
)
class
StableDiffusionOnnxPipeline
(
OnnxStableDiffusionPipeline
):
def
__init__
(
self
,
vae_decoder
:
OnnxRuntimeModel
,
text_encoder
:
OnnxRuntimeModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
OnnxRuntimeModel
,
scheduler
:
Union
[
DDIMScheduler
,
PNDMScheduler
,
LMSDiscreteScheduler
],
safety_checker
:
OnnxRuntimeModel
,
feature_extractor
:
CLIPFeatureExtractor
,
):
deprecation_message
=
"Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
deprecate
(
"StableDiffusionOnnxPipeline"
,
"1.0.0"
,
deprecation_message
)
super
().
__init__
(
vae_decoder
=
vae_decoder
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
)
src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
View file @
728a3f3e
...
@@ -4,6 +4,21 @@
...
@@ -4,6 +4,21 @@
from
..utils
import
DummyObject
,
requires_backends
from
..utils
import
DummyObject
,
requires_backends
class
OnnxStableDiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
,
"transformers"
,
"onnx"
])
@
classmethod
def
from_config
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"onnx"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"onnx"
])
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
...
...
tests/test_pipelines.py
View file @
728a3f3e
...
@@ -37,13 +37,13 @@ from diffusers import (
...
@@ -37,13 +37,13 @@ from diffusers import (
LDMPipeline
,
LDMPipeline
,
LDMTextToImagePipeline
,
LDMTextToImagePipeline
,
LMSDiscreteScheduler
,
LMSDiscreteScheduler
,
OnnxStableDiffusionPipeline
,
PNDMPipeline
,
PNDMPipeline
,
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVePipeline
,
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
ScoreSdeVeScheduler
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionInpaintPipeline
,
StableDiffusionInpaintPipeline
,
StableDiffusionOnnxPipeline
,
StableDiffusionPipeline
,
StableDiffusionPipeline
,
UNet2DConditionModel
,
UNet2DConditionModel
,
UNet2DModel
,
UNet2DModel
,
...
@@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_stable_diffusion_onnx
(
self
):
def
test_stable_diffusion_onnx
(
self
):
sd_pipe
=
StableDiffusion
Onnx
Pipeline
.
from_pretrained
(
sd_pipe
=
Onnx
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"onnx"
,
provider
=
"CPUExecutionProvider"
"CompVis/stable-diffusion-v1-4"
,
revision
=
"onnx"
,
provider
=
"CPUExecutionProvider"
)
)
...
@@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase):
test_callback_fn
.
has_been_called
=
False
test_callback_fn
.
has_been_called
=
False
pipe
=
StableDiffusion
Onnx
Pipeline
.
from_pretrained
(
pipe
=
Onnx
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"onnx"
,
provider
=
"CPUExecutionProvider"
"CompVis/stable-diffusion-v1-4"
,
revision
=
"onnx"
,
provider
=
"CPUExecutionProvider"
)
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment