Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
728a3f3e
Unverified
Commit
728a3f3e
authored
Oct 18, 2022
by
Anton Lozhkov
Committed by
GitHub
Oct 18, 2022
Browse files
Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)
Rename and deprecate
parent
100e094c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
47 additions
and
8 deletions
+47
-8
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-1
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
...elines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+26
-2
src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
...rs/utils/dummy_torch_and_transformers_and_onnx_objects.py
+15
-0
tests/test_pipelines.py
tests/test_pipelines.py
+3
-3
No files found.
src/diffusers/__init__.py
View file @
728a3f3e
...
...
@@ -58,7 +58,7 @@ else:
from
.utils.dummy_torch_and_transformers_objects
import
*
# noqa F403
if
is_torch_available
()
and
is_transformers_available
()
and
is_onnx_available
():
from
.pipelines
import
StableDiffusionOnnxPipeline
from
.pipelines
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
else
:
from
.utils.dummy_torch_and_transformers_and_onnx_objects
import
*
# noqa F403
...
...
src/diffusers/pipelines/__init__.py
View file @
728a3f3e
...
...
@@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available():
)
if
is_transformers_available
()
and
is_onnx_available
():
from
.stable_diffusion
import
StableDiffusionOnnxPipeline
from
.stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
from
.stable_diffusion
import
FlaxStableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
728a3f3e
...
...
@@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available():
from
.safety_checker
import
StableDiffusionSafetyChecker
if
is_transformers_available
()
and
is_onnx_available
():
from
.pipeline_stable_diffusion
_onnx
import
StableDiffusionOnnxPipeline
from
.pipeline_
onnx_
stable_diffusion
import
OnnxStableDiffusionPipeline
,
StableDiffusionOnnxPipeline
if
is_transformers_available
()
and
is_flax_available
():
import
flax
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion
_onnx
.py
→
src/diffusers/pipelines/stable_diffusion/pipeline_
onnx_
stable_diffusion.py
View file @
728a3f3e
...
...
@@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
from
...onnx_utils
import
OnnxRuntimeModel
from
...pipeline_utils
import
DiffusionPipeline
from
...schedulers
import
DDIMScheduler
,
LMSDiscreteScheduler
,
PNDMScheduler
from
...utils
import
logging
from
...utils
import
deprecate
,
logging
from
.
import
StableDiffusionPipelineOutput
logger
=
logging
.
get_logger
(
__name__
)
class
StableDiffusion
Onnx
Pipeline
(
DiffusionPipeline
):
class
Onnx
StableDiffusionPipeline
(
DiffusionPipeline
):
vae_decoder
:
OnnxRuntimeModel
text_encoder
:
OnnxRuntimeModel
tokenizer
:
CLIPTokenizer
...
...
@@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
return
(
image
,
has_nsfw_concept
)
return
StableDiffusionPipelineOutput
(
images
=
image
,
nsfw_content_detected
=
has_nsfw_concept
)
class
StableDiffusionOnnxPipeline
(
OnnxStableDiffusionPipeline
):
def
__init__
(
self
,
vae_decoder
:
OnnxRuntimeModel
,
text_encoder
:
OnnxRuntimeModel
,
tokenizer
:
CLIPTokenizer
,
unet
:
OnnxRuntimeModel
,
scheduler
:
Union
[
DDIMScheduler
,
PNDMScheduler
,
LMSDiscreteScheduler
],
safety_checker
:
OnnxRuntimeModel
,
feature_extractor
:
CLIPFeatureExtractor
,
):
deprecation_message
=
"Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
deprecate
(
"StableDiffusionOnnxPipeline"
,
"1.0.0"
,
deprecation_message
)
super
().
__init__
(
vae_decoder
=
vae_decoder
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
,
safety_checker
=
safety_checker
,
feature_extractor
=
feature_extractor
,
)
src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
View file @
728a3f3e
...
...
@@ -4,6 +4,21 @@
from
..utils
import
DummyObject
,
requires_backends
class
OnnxStableDiffusionPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
,
"transformers"
,
"onnx"
])
@
classmethod
def
from_config
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"onnx"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
,
"transformers"
,
"onnx"
])
class
StableDiffusionOnnxPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
,
"transformers"
,
"onnx"
]
...
...
tests/test_pipelines.py
View file @
728a3f3e
...
...
@@ -37,13 +37,13 @@ from diffusers import (
LDMPipeline
,
LDMTextToImagePipeline
,
LMSDiscreteScheduler
,
OnnxStableDiffusionPipeline
,
PNDMPipeline
,
PNDMScheduler
,
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionInpaintPipeline
,
StableDiffusionOnnxPipeline
,
StableDiffusionPipeline
,
UNet2DConditionModel
,
UNet2DModel
,
...
...
@@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
def
test_stable_diffusion_onnx
(
self
):
sd_pipe
=
StableDiffusion
Onnx
Pipeline
.
from_pretrained
(
sd_pipe
=
Onnx
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"onnx"
,
provider
=
"CPUExecutionProvider"
)
...
...
@@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase):
test_callback_fn
.
has_been_called
=
False
pipe
=
StableDiffusion
Onnx
Pipeline
.
from_pretrained
(
pipe
=
Onnx
StableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"onnx"
,
provider
=
"CPUExecutionProvider"
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment